From f94df6fd6b2c9ae9496c41d9c3291201f5e6f829 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 7 Nov 2025 21:19:00 +0000 Subject: [PATCH 1/5] modernized typehints, None duo scaling Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 146 +++++++++++--------- src/llmcompressor/modifiers/awq/mappings.py | 13 +- 2 files changed, 87 insertions(+), 72 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index bc86ba25f6..380f23394d 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,5 +1,4 @@ import inspect -from typing import Dict, List, Optional, Tuple, Union import torch from compressed_tensors.quantization import disable_quantization @@ -111,31 +110,40 @@ class AWQModifier(Modifier, QuantizationMixin): device. Defaults to None, so cached args are not offloaded. Consider setting to torch.device("cpu") if you are encountering OOM errors :param duo_scaling: whether to use duo scaling, which uses both input activations - and weights to determine the scaling factor + and weights to determine the scaling factor. Defaults to None + If False, only activations are used. + If True, both activations and weights are used. + If None, half the grid search is performed with duo_scaling=False and the + other half is performed with duo_scaling=True. + :param n_grid: when performing the best scales grid search for each mapping, + this specifies how many grid points should be used. To decrease the runtime, + at the possible cost of slightly worse scales, this can be decreased. + Defaults to 20 """ # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) # User-provided vars (in addition to QuantizationMixin args) - sequential_targets: Union[str, List[str], None] = None - mappings: Optional[List[AWQMapping]] = None - offload_device: Optional[torch.device] = None - duo_scaling: bool = True + sequential_targets: str | list[str] | None = None + mappings: list[AWQMapping] | None = None + offload_device: torch.device | None = None + duo_scaling: bool | None = None + n_grid: int = 20 # Private vars set during validation - _num_bits: Optional[int] = PrivateAttr(default=None) - _symmetric: Optional[bool] = PrivateAttr(default=None) - _group_size: Optional[int] = PrivateAttr(default=None) + _num_bits: int | None = PrivateAttr(default=None) + _symmetric: bool | None = PrivateAttr(default=None) + _group_size: int | None = PrivateAttr(default=None) # Private vars set during initialization, cleared during finalization - _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list) + _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list) # Cache list of forward input args for each parent module, one dict for each batch - _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr( + _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr( default_factory=dict ) # Dict[smooth layer name, (activation means, activation counts)] - _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr( + _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr( default_factory=dict ) @@ -389,7 +397,7 @@ def _setup_activation_cache_hooks(self) -> None: def cache_parent_kwargs_hook( module: torch.nn.Module, - args: Tuple[torch.Tensor, ...], + args: tuple[torch.Tensor, ...], kwargs, ): values = inspect.signature(module.forward).bind(*args, **kwargs) @@ -398,7 +406,7 @@ def cache_parent_kwargs_hook( def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( _module: torch.nn.Module, - args: Tuple[torch.Tensor, ...], + args: tuple[torch.Tensor, ...], _output: torch.Tensor, ): self._smooth_activation_means[smooth_name] = _accumulate_mean( @@ -559,13 +567,13 @@ def _smooth(module): v.batch_intermediates.clear() self._assert_all_activations_consumed() - def _run_samples(self, module: Module) -> List[torch.Tensor]: + def _run_samples(self, module: Module) -> list[torch.Tensor]: outputs = [ module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] ] return [ # If Tuple, assume that first argument is the input - output[0] if isinstance(output, Tuple) else output + output[0] if isinstance(output, tuple) else output for output in outputs ] @@ -574,8 +582,8 @@ def _compute_best_scale( x_mean: torch.Tensor, w_mean: torch.Tensor, parent_module: torch.nn.Module, - linears2scale: List[torch.nn.Linear], - fp16_outputs: List[torch.Tensor], + linears2scale: list[torch.nn.Linear], + fp16_outputs: list[torch.Tensor], ) -> torch.Tensor: """ Compute loss and select best scales @@ -586,7 +594,6 @@ def _compute_best_scale( W: original weights in FP16 | layer s: per channel scaling factor | s^-1 * X """ - n_grid = 20 history = [] best_ratio = -1 best_scales = None @@ -602,52 +609,61 @@ def _compute_best_scale( x_mean = x_mean.view(-1).to(device) w_mean = w_mean.view(-1).to(device) + if self.duo_scaling is None: + # if self.duo_scaling is unsert, perform half the grid search with + # duo_scaling off and half with duo_scaling on + n_grid = int(self.n_grid / 2) + duo_scalings = [False, True] + else: + n_grid = self.n_grid + duo_scalings = [self.duo_scaling] for ratio in range(n_grid): - # create new scales - ratio = ratio / n_grid - - # NOTE: s^-1 * x is fused here, according to paper - if self.duo_scaling: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( - min=1e-4 - ) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - _scalesview = scales.view(1, -1).to(device) - - # avoid scaling values that overflow - scales[torch.isinf(scales)] = 1 - scales[torch.isnan(scales)] = 1 - - # Q(W * s) - for linear in linears2scale: - linear.weight.mul_(_scalesview) - update_offload_parameter( - linear, - "weight", - _pseudo_quantize_tensor( - w=linear.weight.data, - symmetric=self._symmetric, - bit_width=self._num_bits, - group_size=self._group_size, - )[0] - / _scalesview, - ) + for duo_scaling in duo_scalings: + # create new scales + ratio = ratio / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + _scalesview = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for linear in linears2scale: + linear.weight.mul_(_scalesview) + update_offload_parameter( + linear, + "weight", + _pseudo_quantize_tensor( + w=linear.weight.data, + symmetric=self._symmetric, + bit_width=self._num_bits, + group_size=self._group_size, + )[0] + / _scalesview, + ) - # W * X - int_w_outputs = self._run_samples(parent_module) + # W * X + int_w_outputs = self._run_samples(parent_module) - # compute mean squared error (L2 norm) - loss = self._compute_loss(fp16_outputs, int_w_outputs, device) + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_outputs, int_w_outputs, device) - history.append(loss) - if loss < best_error: - best_error = loss - best_ratio = ratio - best_scales = scales.clone() + history.append(loss) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scales = scales.clone() - parent_module.load_state_dict(org_sd, strict=False) + parent_module.load_state_dict(org_sd, strict=False) if best_ratio == -1: logger.debug(history) @@ -667,8 +683,8 @@ def _compute_best_scale( @torch.no_grad() def _compute_loss( self, - fp16_outputs: List[torch.Tensor], - int_w_outputs: List[torch.Tensor], + fp16_outputs: list[torch.Tensor], + int_w_outputs: list[torch.Tensor], device: torch.device, ) -> torch.Tensor: loss = 0.0 @@ -746,8 +762,8 @@ def _pseudo_quantize_tensor( def _accumulate_mean( inp: torch.Tensor, - prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]], -) -> Tuple[torch.FloatTensor, int]: + prev_mean_and_count: tuple[torch.FloatTensor, int] | None, +) -> tuple[torch.FloatTensor, int]: sum_added = inp.sum(dim=0) num_added = inp.size(0) if prev_mean_and_count is None: @@ -761,7 +777,7 @@ def _accumulate_mean( return (prev_sum + sum_added) / new_count, new_count -def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]: +def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]: """ Given a list of names, returns the lowest-scope common parent. diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py index 907bca4880..9bec035df4 100644 --- a/src/llmcompressor/modifiers/awq/mappings.py +++ b/src/llmcompressor/modifiers/awq/mappings.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Dict, List, Optional from loguru import logger from torch.nn import Module @@ -143,7 +142,7 @@ class AWQMapping: # ["re:.*dense$"] # ), ] -AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = { +AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = { "BloomForCausalLM": _bloom_mappings, "CohereForCausalLM": _cohere_mappings, "Cohere2ForCausalLM": _cohere_mappings, @@ -186,13 +185,13 @@ class ResolvedMapping: smooth_name: str smooth_layer: Module - balance_layers: List[Module] - balance_names: Optional[List[str]] = None - parent: Optional[Module] = None - parent_name: Optional[str] = None + balance_layers: list[Module] + balance_names: list[str] + parent: Module + parent_name: str -def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]: +def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]: """ :param architecture: str: The architecture of the model :return: list: The layer mappings for the given architecture From ac6f6d249e86b16b20c152ac450c7cc3eaa51eea Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 7 Nov 2025 21:29:45 +0000 Subject: [PATCH 2/5] use itertools.product Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 88 ++++++++++++------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 380f23394d..6fa32bb110 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,4 +1,5 @@ import inspect +from itertools import product import torch from compressed_tensors.quantization import disable_quantization @@ -610,60 +611,59 @@ def _compute_best_scale( w_mean = w_mean.view(-1).to(device) if self.duo_scaling is None: - # if self.duo_scaling is unsert, perform half the grid search with + # if self.duo_scaling is unset, perform half the grid search with # duo_scaling off and half with duo_scaling on n_grid = int(self.n_grid / 2) duo_scalings = [False, True] else: n_grid = self.n_grid duo_scalings = [self.duo_scaling] - for ratio in range(n_grid): - for duo_scaling in duo_scalings: - # create new scales - ratio = ratio / n_grid - - # NOTE: s^-1 * x is fused here, according to paper - if duo_scaling: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( - min=1e-4 - ) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - _scalesview = scales.view(1, -1).to(device) - - # avoid scaling values that overflow - scales[torch.isinf(scales)] = 1 - scales[torch.isnan(scales)] = 1 - - # Q(W * s) - for linear in linears2scale: - linear.weight.mul_(_scalesview) - update_offload_parameter( - linear, - "weight", - _pseudo_quantize_tensor( - w=linear.weight.data, - symmetric=self._symmetric, - bit_width=self._num_bits, - group_size=self._group_size, - )[0] - / _scalesview, - ) + for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): + # create new scales + ratio = grid_idx / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if use_duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + _scalesview = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for linear in linears2scale: + linear.weight.mul_(_scalesview) + update_offload_parameter( + linear, + "weight", + _pseudo_quantize_tensor( + w=linear.weight.data, + symmetric=self._symmetric, + bit_width=self._num_bits, + group_size=self._group_size, + )[0] + / _scalesview, + ) - # W * X - int_w_outputs = self._run_samples(parent_module) + # W * X + int_w_outputs = self._run_samples(parent_module) - # compute mean squared error (L2 norm) - loss = self._compute_loss(fp16_outputs, int_w_outputs, device) + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_outputs, int_w_outputs, device) - history.append(loss) - if loss < best_error: - best_error = loss - best_ratio = ratio - best_scales = scales.clone() + history.append(loss) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scales = scales.clone() - parent_module.load_state_dict(org_sd, strict=False) + parent_module.load_state_dict(org_sd, strict=False) if best_ratio == -1: logger.debug(history) From f8e71da710c4faa8574766660e076afbc3e6234d Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Nov 2025 16:44:26 +0000 Subject: [PATCH 3/5] change default back to previous Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 6fa32bb110..9a0eebb187 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -111,9 +111,9 @@ class AWQModifier(Modifier, QuantizationMixin): device. Defaults to None, so cached args are not offloaded. Consider setting to torch.device("cpu") if you are encountering OOM errors :param duo_scaling: whether to use duo scaling, which uses both input activations - and weights to determine the scaling factor. Defaults to None - If False, only activations are used. + and weights to determine the scaling factor. Defaults to True If True, both activations and weights are used. + If False, only activations are used. If None, half the grid search is performed with duo_scaling=False and the other half is performed with duo_scaling=True. :param n_grid: when performing the best scales grid search for each mapping, @@ -129,7 +129,7 @@ class AWQModifier(Modifier, QuantizationMixin): sequential_targets: str | list[str] | None = None mappings: list[AWQMapping] | None = None offload_device: torch.device | None = None - duo_scaling: bool | None = None + duo_scaling: bool | None = True n_grid: int = 20 # Private vars set during validation From c6dca38c058833968964a7cde5000e34bf521063 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Nov 2025 16:22:00 +0000 Subject: [PATCH 4/5] switch duo_scaling None to both Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 7 ++++--- tests/llmcompressor/modifiers/awq/test_base.py | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 9a0eebb187..14e48876f5 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,5 +1,6 @@ import inspect from itertools import product +from typing import Literal import torch from compressed_tensors.quantization import disable_quantization @@ -114,7 +115,7 @@ class AWQModifier(Modifier, QuantizationMixin): and weights to determine the scaling factor. Defaults to True If True, both activations and weights are used. If False, only activations are used. - If None, half the grid search is performed with duo_scaling=False and the + If "both", half the grid search is performed with duo_scaling=False and the other half is performed with duo_scaling=True. :param n_grid: when performing the best scales grid search for each mapping, this specifies how many grid points should be used. To decrease the runtime, @@ -129,7 +130,7 @@ class AWQModifier(Modifier, QuantizationMixin): sequential_targets: str | list[str] | None = None mappings: list[AWQMapping] | None = None offload_device: torch.device | None = None - duo_scaling: bool | None = True + duo_scaling: bool | Literal["both"] = True n_grid: int = 20 # Private vars set during validation @@ -610,7 +611,7 @@ def _compute_best_scale( x_mean = x_mean.view(-1).to(device) w_mean = w_mean.view(-1).to(device) - if self.duo_scaling is None: + if self.duo_scaling.lower() == "both": # if self.duo_scaling is unset, perform half the grid search with # duo_scaling off and half with duo_scaling on n_grid = int(self.n_grid / 2) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index a66a278f32..950ab0f51a 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -171,6 +171,12 @@ def test_validate(): } ) + AWQModifier(scheme="W4A16", duo_scaling="both") + with pytest.raises(ValidationError): + AWQModifier(scheme="W4A16", duo_scaling="Both") + with pytest.raises(ValidationError): + AWQModifier(scheme="W4A16", duo_scaling="x") + @pytest.mark.unit def test_get_lowest_common_parent(): From fc43ad07a032808596e0bc044dc15464dd2ea073 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Nov 2025 16:26:22 +0000 Subject: [PATCH 5/5] switch to match statement Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 14e48876f5..98e53b4e00 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -464,9 +464,11 @@ def _apply_smoothing(self, model: Module) -> None: balance_layers = mapping.balance_layers parent_module = mapping.parent - with align_modules( - [parent_module, smooth_layer, *balance_layers] - ), calibration_forward_context(model), HooksMixin.disable_hooks(): + with ( + align_modules([parent_module, smooth_layer, *balance_layers]), + calibration_forward_context(model), + HooksMixin.disable_hooks(), + ): # [STEP 1]: Compute per-channel mean of normalised weights # All layer weights are concatted together weight = torch.cat([bl.weight for bl in balance_layers], dim=0) @@ -611,14 +613,15 @@ def _compute_best_scale( x_mean = x_mean.view(-1).to(device) w_mean = w_mean.view(-1).to(device) - if self.duo_scaling.lower() == "both": - # if self.duo_scaling is unset, perform half the grid search with + match self.duo_scaling: + # if self.duo_scaling is "both", perform half the grid search with # duo_scaling off and half with duo_scaling on - n_grid = int(self.n_grid / 2) - duo_scalings = [False, True] - else: - n_grid = self.n_grid - duo_scalings = [self.duo_scaling] + case "both": + n_grid = int(self.n_grid / 2) + duo_scalings = [False, True] + case _: + n_grid = self.n_grid + duo_scalings = [self.duo_scaling] for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): # create new scales ratio = grid_idx / n_grid