
Commit 2507b87

[AWQ] modernized typehints, allow for n_grid/duo_scaling parameters (#2003)
SUMMARY: This updates the AWQModifier type hints to Python 3.10+, exposes n_grid as a parameter rather than the hard-coded 20, and allows duo_scaling to be set to "both", in which case half of the grid search is performed with duo_scaling off and the other half with duo_scaling on. This stems from @fynnsu's finding that the previous default value of True actually led to worse behavior in some circumstances.

TEST PLAN: Run benchmarks with different values of n_grid and duo_scaling to see how much this affects accuracy/runtime.

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Fynn Schmitt-Ulms <fynnsu@outlook.com>
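A minimal usage sketch of the new options (the oneshot call, model name, and dataset below are illustrative placeholders, not part of this commit):

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# "both" tries every grid ratio with duo scaling off and on;
# n_grid=10 (default 20) trades some scale quality for a faster search.
recipe = AWQModifier(scheme="W4A16", duo_scaling="both", n_grid=10)

# Hypothetical calibration run; model and dataset are placeholders.
oneshot(model="Qwen/Qwen2.5-0.5B-Instruct", dataset="open_platypus", recipe=recipe)
```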
1 parent f3a8e78 commit 2507b87

File tree

3 files changed (+62, -37 lines)

src/llmcompressor/modifiers/awq/base.py

Lines changed: 50 additions & 30 deletions
@@ -1,5 +1,6 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union
+from itertools import product
+from typing import Literal
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -111,31 +112,40 @@ class AWQModifier(Modifier, QuantizationMixin):
         device. Defaults to None, so cached args are not offloaded. Consider setting
         to torch.device("cpu") if you are encountering OOM errors
     :param duo_scaling: whether to use duo scaling, which uses both input activations
-        and weights to determine the scaling factor
+        and weights to determine the scaling factor. Defaults to True
+        If True, both activations and weights are used.
+        If False, only activations are used.
+        If "both", half the grid search is performed with duo_scaling=False and the
+        other half is performed with duo_scaling=True.
+    :param n_grid: when performing the best scales grid search for each mapping,
+        this specifies how many grid points should be used. To decrease the runtime,
+        at the possible cost of slightly worse scales, this can be decreased.
+        Defaults to 20
     """
 
     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
 
     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
-    mappings: Optional[List[AWQMapping]] = None
-    offload_device: Optional[torch.device] = None
-    duo_scaling: bool = True
+    sequential_targets: str | list[str] | None = None
+    mappings: list[AWQMapping] | None = None
+    offload_device: torch.device | None = None
+    duo_scaling: bool | Literal["both"] = True
+    n_grid: int = 20
 
     # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
+    _num_bits: int | None = PrivateAttr(default=None)
+    _symmetric: bool | None = PrivateAttr(default=None)
+    _group_size: int | None = PrivateAttr(default=None)
 
     # Private vars set during initialization, cleared during finalization
-    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
     # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )

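The two modes described in the updated docstring reduce to the following scale formulas. This is a distilled sketch with toy statistics: in the modifier, x_mean and w_mean are per-channel means gathered during calibration, and the duo_scaling=False branch (not shown in this diff) is assumed to follow the upstream AWQ formulation:

```python
import torch

x_mean = torch.tensor([0.5, 2.0, 8.0])   # toy per-channel activation means
w_mean = torch.tensor([1.0, 0.25, 4.0])  # toy per-channel weight means
ratio = 0.5                              # one grid point in [0, 1)

# duo_scaling on: activations and weights jointly determine the scales
duo = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)

# duo_scaling off: activations alone determine the scales (assumed branch)
solo = x_mean.pow(ratio).clamp(min=1e-4)
```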
@@ -389,7 +399,7 @@ def _setup_activation_cache_hooks(self) -> None:
 
         def cache_parent_kwargs_hook(
             module: torch.nn.Module,
-            args: Tuple[torch.Tensor, ...],
+            args: tuple[torch.Tensor, ...],
             kwargs,
         ):
             values = inspect.signature(module.forward).bind(*args, **kwargs)
@@ -398,7 +408,7 @@ def cache_parent_kwargs_hook(
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
+                args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
                 self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -454,9 +464,11 @@ def _apply_smoothing(self, model: Module) -> None:
             balance_layers = mapping.balance_layers
             parent_module = mapping.parent
 
-            with align_modules(
-                [parent_module, smooth_layer, *balance_layers]
-            ), calibration_forward_context(model), HooksMixin.disable_hooks():
+            with (
+                align_modules([parent_module, smooth_layer, *balance_layers]),
+                calibration_forward_context(model),
+                HooksMixin.disable_hooks(),
+            ):
                 # [STEP 1]: Compute per-channel mean of normalised weights
                 # All layer weights are concatted together
                 weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
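The rewritten with-statement relies on parenthesized context managers, valid on the Python 3.10+ baseline this PR targets. A generic sketch of the syntax, using contextlib stand-ins:

```python
from contextlib import nullcontext

# Any objects implementing __enter__/__exit__ can be listed this way.
with (
    nullcontext("a") as first,
    nullcontext("b") as second,
):
    print(first, second)  # prints: a b
```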
@@ -559,13 +571,13 @@ def _smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
         return [
             # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
+            output[0] if isinstance(output, tuple) else output
             for output in outputs
         ]

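The isinstance change above is a small correctness cleanup: the builtin tuple is the proper runtime check, while typing.Tuple is an annotation-time alias. A quick illustration (the deprecation detail is my note, not stated in the diff):

```python
from typing import Tuple

value = (1, 2)
assert isinstance(value, tuple)  # builtin type: the idiomatic runtime check
# isinstance(value, Tuple) also happens to work on CPython, but typing.Tuple
# has been deprecated in favor of tuple since Python 3.9.
```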
@@ -574,8 +586,8 @@ def _compute_best_scale(
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        linears2scale: list[torch.nn.Linear],
+        fp16_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -586,7 +598,6 @@ def _compute_best_scale(
         W: original weights in FP16 | layer
         s: per channel scaling factor | s^-1 * X
         """
-        n_grid = 20
         history = []
         best_ratio = -1
         best_scales = None
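Putting the docstring fragments together, the quantity minimized over the grid is the output reconstruction error from the AWQ paper; schematically (my notation, following the docstring's symbols):

```python
# Schematic objective for the grid search:
#
#   s* = argmin_s || Q(W * s) @ (s^-1 * X) - W @ X ||
#
# where Q pseudo-quantizes the scaled weights, W are the original FP16
# weights, X the cached inputs, and s one candidate per-channel scale.
```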
@@ -602,12 +613,21 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        match self.duo_scaling:
+            # if self.duo_scaling is "both", perform half the grid search with
+            # duo_scaling off and half with duo_scaling on
+            case "both":
+                n_grid = int(self.n_grid / 2)
+                duo_scalings = [False, True]
+            case _:
+                n_grid = self.n_grid
+                duo_scalings = [self.duo_scaling]
+        for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
-            if self.duo_scaling:
+            if use_duo_scaling:
                 scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
                     min=1e-4
                 )
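To make the new control flow concrete, here is the enumeration in isolation (a standalone sketch with an illustrative n_grid of 4):

```python
from itertools import product

n_grid_param, duo_scaling = 4, "both"  # illustrative values

if duo_scaling == "both":
    n_grid, duo_scalings = n_grid_param // 2, [False, True]
else:
    n_grid, duo_scalings = n_grid_param, [duo_scaling]

for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
    print(grid_idx / n_grid, use_duo_scaling)

# Prints 0.0 False, 0.0 True, 0.5 False, 0.5 True: each ratio is evaluated
# in both modes, so the total number of loss evaluations stays n_grid_param.
```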
@@ -667,8 +687,8 @@ def _compute_best_scale(
     @torch.no_grad()
     def _compute_loss(
         self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
+        fp16_outputs: list[torch.Tensor],
+        int_w_outputs: list[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
@@ -746,8 +766,8 @@ def _pseudo_quantize_tensor(
 
 def _accumulate_mean(
     inp: torch.Tensor,
-    prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
-) -> Tuple[torch.FloatTensor, int]:
+    prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
+) -> tuple[torch.FloatTensor, int]:
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
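_accumulate_mean carries a running (mean, count) pair per smooth layer; the update is a standard streaming mean. A standalone check of the arithmetic (not code from this commit):

```python
import torch

batches = [torch.randn(4, 8), torch.randn(2, 8), torch.randn(7, 8)]

mean, count = None, 0
for b in batches:
    prev_sum = mean * count if mean is not None else 0.0
    count += b.size(0)
    mean = (prev_sum + b.sum(dim=0)) / count

# The streamed mean matches the mean over all rows computed at once.
assert torch.allclose(mean, torch.cat(batches).mean(dim=0), atol=1e-5)
```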
@@ -761,7 +781,7 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count
 
 
-def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
     """
     Given a list of names, returns the lowest-scope common parent.

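For orientation, get_lowest_common_parent (only its hints change here) returns the deepest module containing all of the named submodules. A toy illustration, with the expected result inferred from the docstring rather than taken from the diff:

```python
from torch.nn import Linear, Module, ModuleList

from llmcompressor.modifiers.awq.base import get_lowest_common_parent

class Block(Module):
    def __init__(self):
        super().__init__()
        self.q_proj = Linear(4, 4)
        self.k_proj = Linear(4, 4)

class Toy(Module):
    def __init__(self):
        super().__init__()
        self.layers = ModuleList([Block()])

name, parent = get_lowest_common_parent(["layers.0.q_proj", "layers.0.k_proj"], Toy())
# Expected: name == "layers.0" and parent is that Block instance.
```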
src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 6 additions & 7 deletions
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional
 
 from loguru import logger
 from torch.nn import Module
@@ -143,7 +142,7 @@ class AWQMapping:
     # ["re:.*dense$"]
     # ),
 ]
-AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
+AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
     "BloomForCausalLM": _bloom_mappings,
     "CohereForCausalLM": _cohere_mappings,
     "Cohere2ForCausalLM": _cohere_mappings,
@@ -186,13 +185,13 @@ class ResolvedMapping:
 
     smooth_name: str
     smooth_layer: Module
-    balance_layers: List[Module]
-    balance_names: Optional[List[str]] = None
-    parent: Optional[Module] = None
-    parent_name: Optional[str] = None
+    balance_layers: list[Module]
+    balance_names: list[str]
+    parent: Module
+    parent_name: str
 
 
-def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
+def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
     """
     :param architecture: str: The architecture of the model
     :return: list: The layer mappings for the given architecture

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 6 additions & 0 deletions
@@ -171,6 +171,12 @@ def test_validate():
         }
     )
 
+    AWQModifier(scheme="W4A16", duo_scaling="both")
+    with pytest.raises(ValidationError):
+        AWQModifier(scheme="W4A16", duo_scaling="Both")
+    with pytest.raises(ValidationError):
+        AWQModifier(scheme="W4A16", duo_scaling="x")
+
 
 @pytest.mark.unit
 def test_get_lowest_common_parent():