
Commit f94df6f

modernized typehints, None duo scaling
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 1c85a66 commit f94df6f

2 files changed (+87 / -72 lines)

src/llmcompressor/modifiers/awq/base.py

Lines changed: 81 additions & 65 deletions
@@ -1,5 +1,4 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union

 import torch
 from compressed_tensors.quantization import disable_quantization
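
The pattern applied throughout both files: typing aliases are replaced by builtin
generics (PEP 585) and X | Y unions (PEP 604). A minimal before/after sketch with a
hypothetical function, not code from this commit:

    # before: typing aliases
    from typing import Dict, List, Optional

    def f(xs: List[int], lookup: Optional[Dict[str, int]] = None) -> Optional[int]:
        ...

    # after: builtin generics (Python 3.9+) and union syntax (Python 3.10+)
    def f(xs: list[int], lookup: dict[str, int] | None = None) -> int | None:
        ...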
@@ -111,31 +110,40 @@ class AWQModifier(Modifier, QuantizationMixin):
         device. Defaults to None, so cached args are not offloaded. Consider setting
         to torch.device("cpu") if you are encountering OOM errors
     :param duo_scaling: whether to use duo scaling, which uses both input activations
-        and weights to determine the scaling factor
+        and weights to determine the scaling factor. Defaults to None.
+        If False, only activations are used.
+        If True, both activations and weights are used.
+        If None, half the grid search is performed with duo_scaling=False and the
+        other half is performed with duo_scaling=True.
+    :param n_grid: when performing the best scales grid search for each mapping,
+        this specifies how many grid points should be used. To decrease the runtime,
+        at the possible cost of slightly worse scales, this can be decreased.
+        Defaults to 20.
     """

     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
-    mappings: Optional[List[AWQMapping]] = None
-    offload_device: Optional[torch.device] = None
-    duo_scaling: bool = True
+    sequential_targets: str | list[str] | None = None
+    mappings: list[AWQMapping] | None = None
+    offload_device: torch.device | None = None
+    duo_scaling: bool | None = None
+    n_grid: int = 20

     # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
+    _num_bits: int | None = PrivateAttr(default=None)
+    _symmetric: bool | None = PrivateAttr(default=None)
+    _group_size: int | None = PrivateAttr(default=None)

     # Private vars set during initialization, cleared during finalization
-    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
     # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )
@@ -389,7 +397,7 @@ def _setup_activation_cache_hooks(self) -> None:

         def cache_parent_kwargs_hook(
             module: torch.nn.Module,
-            args: Tuple[torch.Tensor, ...],
+            args: tuple[torch.Tensor, ...],
             kwargs,
         ):
             values = inspect.signature(module.forward).bind(*args, **kwargs)
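
For context, this hook relies on PyTorch's kwargs-aware pre-hooks. A standalone
sketch of the same binding trick on a toy module (not the modifier's actual
registration code):

    import inspect

    import torch

    cache = []

    def cache_kwargs_hook(module, args, kwargs):
        # bind positionals to parameter names so each cached batch can later
        # be replayed uniformly as module(**batch_kwargs)
        bound = inspect.signature(module.forward).bind(*args, **kwargs)
        cache.append(bound.arguments)

    linear = torch.nn.Linear(4, 4)
    handle = linear.register_forward_pre_hook(cache_kwargs_hook, with_kwargs=True)
    linear(torch.randn(2, 4))
    assert "input" in cache[0]
    handle.remove()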
@@ -398,7 +406,7 @@ def cache_parent_kwargs_hook(
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
+                args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
                 self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -559,13 +567,13 @@ def _smooth(module):
                 v.batch_intermediates.clear()
         self._assert_all_activations_consumed()

-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
         return [
             # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
+            output[0] if isinstance(output, tuple) else output
             for output in outputs
         ]
@@ -574,8 +582,8 @@ def _compute_best_scale(
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        linears2scale: list[torch.nn.Linear],
+        fp16_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -586,7 +594,6 @@ def _compute_best_scale(
         W: original weights in FP16 | layer
         s: per channel scaling factor | s^-1 * X
         """
-        n_grid = 20
         history = []
         best_ratio = -1
         best_scales = None
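
Spelled out, the grid search in the next hunk minimizes the AWQ objective from the
paper, in the docstring's notation (Q quantizes weights, W are the FP16 weights,
X the cached inputs, s the per-channel scales):

    s^{*} = \arg\min_{s} \left\lVert Q(W \cdot s)\,(s^{-1} \cdot X) - W X \right\rVert

Each grid point parameterizes s through a ratio alpha, as x_mean^alpha /
w_mean^(1 - alpha) with duo scaling or x_mean^alpha from activations alone.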
@@ -602,52 +609,61 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)

+        if self.duo_scaling is None:
+            # if self.duo_scaling is unset, perform half the grid search with
+            # duo_scaling off and half with duo_scaling on
+            n_grid = int(self.n_grid / 2)
+            duo_scalings = [False, True]
+        else:
+            n_grid = self.n_grid
+            duo_scalings = [self.duo_scaling]
         for ratio in range(n_grid):
-            # create new scales
-            ratio = ratio / n_grid
-
-            # NOTE: s^-1 * x is fused here, according to paper
-            if self.duo_scaling:
-                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
-                    min=1e-4
-                )
-            else:
-                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
-            scales = scales / (scales.max() * scales.min()).sqrt()
-            _scalesview = scales.view(1, -1).to(device)
-
-            # avoid scaling values that overflow
-            scales[torch.isinf(scales)] = 1
-            scales[torch.isnan(scales)] = 1
-
-            # Q(W * s)
-            for linear in linears2scale:
-                linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
-                    _pseudo_quantize_tensor(
-                        w=linear.weight.data,
-                        symmetric=self._symmetric,
-                        bit_width=self._num_bits,
-                        group_size=self._group_size,
-                    )[0]
-                    / _scalesview,
-                )
+            for duo_scaling in duo_scalings:
+                # create new scales (separate name so the integer loop
+                # variable is not rebound on the second duo_scaling pass)
+                grid_ratio = ratio / n_grid
+
+                # NOTE: s^-1 * x is fused here, according to paper
+                if duo_scaling:
+                    scales = (
+                        x_mean.pow(grid_ratio) / (w_mean.pow(1 - grid_ratio) + 1e-4)
+                    ).clamp(min=1e-4)
+                else:
+                    scales = x_mean.pow(grid_ratio).clamp(min=1e-4).view(-1)
+                scales = scales / (scales.max() * scales.min()).sqrt()
+                _scalesview = scales.view(1, -1).to(device)
+
+                # avoid scaling values that overflow
+                scales[torch.isinf(scales)] = 1
+                scales[torch.isnan(scales)] = 1
+
+                # Q(W * s)
+                for linear in linears2scale:
+                    linear.weight.mul_(_scalesview)
+                    update_offload_parameter(
+                        linear,
+                        "weight",
+                        _pseudo_quantize_tensor(
+                            w=linear.weight.data,
+                            symmetric=self._symmetric,
+                            bit_width=self._num_bits,
+                            group_size=self._group_size,
+                        )[0]
+                        / _scalesview,
+                    )

-            # W * X
-            int_w_outputs = self._run_samples(parent_module)
+                # W * X
+                int_w_outputs = self._run_samples(parent_module)

-            # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+                # compute mean squared error (L2 norm)
+                loss = self._compute_loss(fp16_outputs, int_w_outputs, device)

-            history.append(loss)
-            if loss < best_error:
-                best_error = loss
-                best_ratio = ratio
-                best_scales = scales.clone()
+                history.append(loss)
+                if loss < best_error:
+                    best_error = loss
+                    best_ratio = grid_ratio
+                    best_scales = scales.clone()

-        parent_module.load_state_dict(org_sd, strict=False)
+                parent_module.load_state_dict(org_sd, strict=False)

         if best_ratio == -1:
             logger.debug(history)
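
To isolate the new control flow, here is a standalone sketch of the candidate
generation with a hypothetical candidate_scales helper; the real search also
quantizes the weights and measures output loss per candidate:

    import torch

    def candidate_scales(x_mean, w_mean, duo_scaling=None, n_grid=20):
        # duo_scaling=None splits the budget: half the grid points per mode
        if duo_scaling is None:
            n_grid, modes = n_grid // 2, [False, True]
        else:
            modes = [duo_scaling]
        for step in range(n_grid):
            ratio = step / n_grid
            for duo in modes:
                if duo:
                    scales = (
                        x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)
                    ).clamp(min=1e-4)
                else:
                    scales = x_mean.pow(ratio).clamp(min=1e-4)
                # normalize so the geometric mean of the extremes is 1
                yield ratio, duo, scales / (scales.max() * scales.min()).sqrt()

    # 20 candidates either way: 20x1 with a fixed mode, 10x2 with None
    assert len(list(candidate_scales(torch.rand(8), torch.rand(8)))) == 20

Note that the total number of candidates is unchanged, so leaving duo_scaling
unset costs no extra runtime versus a fixed setting.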
@@ -667,8 +683,8 @@ def _compute_best_scale(
     @torch.no_grad()
     def _compute_loss(
         self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
+        fp16_outputs: list[torch.Tensor],
+        int_w_outputs: list[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
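
The body of _compute_loss is not part of this diff; as an assumption-labeled
sketch, a loss with this signature plausibly accumulates squared output error
across all cached batches:

    import torch

    def compute_loss(fp16_outputs, int_w_outputs, device):
        # mean squared error between original and quantized-weight outputs
        total_err, total_n = 0.0, 0
        for fp16, int_w in zip(fp16_outputs, int_w_outputs):
            diff = fp16.to(device) - int_w.to(device)
            total_err += diff.float().pow(2).sum().item()
            total_n += fp16.numel()
        return total_err / total_n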
@@ -746,8 +762,8 @@ def _pseudo_quantize_tensor(

 def _accumulate_mean(
     inp: torch.Tensor,
-    prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
-) -> Tuple[torch.FloatTensor, int]:
+    prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
+) -> tuple[torch.FloatTensor, int]:
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
@@ -761,7 +777,7 @@ def _accumulate_mean(
     return (prev_sum + sum_added) / new_count, new_count


-def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
     """
     Given a list of names, returns the lowest-scope common parent.
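
_accumulate_mean keeps a running per-feature mean so activations never need to be
stored whole. A self-contained sketch of the same update rule (prev_sum in the
hunk above suggests the stored mean is rescaled by its count before folding in a
new batch):

    import torch

    def accumulate_mean(inp, prev_mean_and_count=None):
        # fold one batch into a running mean over dim 0
        sum_added = inp.sum(dim=0)
        num_added = inp.size(0)
        if prev_mean_and_count is None:
            return sum_added / num_added, num_added
        prev_mean, prev_count = prev_mean_and_count
        new_count = prev_count + num_added
        prev_sum = prev_mean * prev_count  # undo the previous division
        return (prev_sum + sum_added) / new_count, new_count

    state = None
    for batch in (torch.randn(4, 8), torch.randn(2, 8)):
        state = accumulate_mean(batch, state)
    mean, count = state  # count == 6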

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 6 additions & 7 deletions
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional

 from loguru import logger
 from torch.nn import Module
@@ -143,7 +142,7 @@ class AWQMapping:
     # ["re:.*dense$"]
     # ),
 ]
-AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
+AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
     "BloomForCausalLM": _bloom_mappings,
     "CohereForCausalLM": _cohere_mappings,
     "Cohere2ForCausalLM": _cohere_mappings,
@@ -186,13 +185,13 @@ class ResolvedMapping:

     smooth_name: str
     smooth_layer: Module
-    balance_layers: List[Module]
-    balance_names: Optional[List[str]] = None
-    parent: Optional[Module] = None
-    parent_name: Optional[str] = None
+    balance_layers: list[Module]
+    balance_names: list[str]
+    parent: Module
+    parent_name: str


-def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
+def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
     """
     :param architecture: str: The architecture of the model
     :return: list: The layer mappings for the given architecture
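
Usage-wise, the registry is keyed by Hugging Face architecture strings. A small
illustrative lookup (CohereForCausalLM is one of the keys shown above; fallback
behavior for unregistered architectures is not part of this hunk):

    from llmcompressor.modifiers.awq.mappings import (
        AWQ_MAPPING_REGISTRY,
        get_layer_mappings_from_architecture,
    )

    # resolve mappings for a registered architecture
    mappings = get_layer_mappings_from_architecture("CohereForCausalLM")

    # for registered names this should agree with direct registry access
    assert mappings == AWQ_MAPPING_REGISTRY["CohereForCausalLM"]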
