From 2dfce99a51019dd565a8fc3d0557db7bd7df429a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 22 Oct 2025 12:17:14 -0400 Subject: [PATCH 01/14] WIP Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/awq/base.py | 181 ++++-------------- .../modifiers/quantization/calibration.py | 2 +- 2 files changed, 33 insertions(+), 150 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 6bc97b446b..e869021441 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -21,7 +21,7 @@ ResolvedMapping, get_layer_mappings_from_architecture, ) -from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale +from llmcompressor.modifiers.quantization.calibration import call_observer, update_weight_zp_scale from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache @@ -123,11 +123,6 @@ class AWQModifier(Modifier, QuantizationMixin): offload_device: Optional[torch.device] = None duo_scaling: bool = True - # Private vars set during validation - _num_bits: Optional[int] = PrivateAttr(default=None) - _symmetric: Optional[bool] = PrivateAttr(default=None) - _group_size: Optional[int] = PrivateAttr(default=None) - # Private vars set during initialization, cleared during finalization _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list) # Cache list of forward input args for each parent module, one dict for each batch @@ -139,72 +134,6 @@ class AWQModifier(Modifier, QuantizationMixin): default_factory=dict ) - # NOTE: different name chosen to avoid collision with - # QuantizationMixin.validate_model_after, which must be called first - @model_validator(mode="after") - def validate_awq_after(model: "AWQModifier") -> "AWQModifier": - """ - Confirm only one configuration for group_size, symmetric, and num_bits, - as AWQ algorithm depends on it - Confirm no activation quantization, as AWQ only works with WNA16 - """ - config = model.resolve_quantization_config() - - num_bits_set = set( - group.weights.num_bits - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(num_bits_set) == 1 - ), "In AWQ, all config groups must use the same configuration for num_bits" - - model._num_bits = next(iter(num_bits_set)) - - symmetric_set = set( - group.weights.symmetric - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(symmetric_set) == 1 - ), "In AWQ, all config groups must use the same configuration for symmetric" - - model._symmetric = next(iter(symmetric_set)) - - group_size_set = set( - group.weights.group_size - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(group_size_set) == 1 - ), "In AWQ, all config groups must use the same configuration for group_size" - - model._group_size = next(iter(group_size_set)) - - in_num_bits_set = set( - group.input_activations.num_bits - for group in config.config_groups.values() - if group.input_activations is not None - ) - assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"input activations {in_num_bits_set} not allowed" - ) - - out_num_bits_set = set( - group.output_activations.num_bits - for group in config.config_groups.values() - if group.output_activations is not None - ) - assert 
len(out_num_bits_set) == 0 or out_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"output activations {out_num_bits_set} not allowed" - ) - - return model - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize AWQ on the given state @@ -455,23 +384,6 @@ def _apply_smoothing(self, model: Module) -> None: with align_modules( [parent_module, smooth_layer, *balance_layers] ), calibration_forward_context(model), HooksMixin.disable_hooks(): - # [STEP 1]: Compute per-channel mean of normalised weights - # All layer weights are concatted together - weight = torch.cat([bl.weight for bl in balance_layers], dim=0) - org_shape = weight.shape - # The weights are reshaped to be organised by quantization group - weight = weight.view(-1, self._group_size) - # Calculates the relative magnitude of the weights within - # each of the quantization groups, and rescales each group - # individually so that each group has weights on a 0-1 scale. - weight.abs_() - weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - # Resizes the rescaled weight matrix back up to its original dimensions - weight = weight.view(org_shape) - # Gets the average rescaled magnitude for each output channel - w_mean = weight.mean(0) - del weight - # [STEP 3]: Compute output of module # could cache from hook, rather than recomputing here fp16_outputs = self._run_samples(parent_module) @@ -498,11 +410,9 @@ def _apply_smoothing(self, model: Module) -> None: del self._smooth_activation_means[mapping.smooth_name] continue - x_mean = self._smooth_activation_means[mapping.smooth_name][0] - # [STEP 4]: Compute loss best_scales = self._compute_best_scale( - x_mean, w_mean, parent_module, balance_layers, fp16_outputs + parent_module, mapping, fp16_outputs ) @torch.no_grad() @@ -566,10 +476,8 @@ def _run_samples(self, module: Module) -> List[torch.Tensor]: def _compute_best_scale( self, - x_mean: torch.Tensor, - w_mean: torch.Tensor, parent_module: torch.nn.Module, - linears2scale: List[torch.nn.Linear], + mapping: ResolvedMapping, fp16_outputs: List[torch.Tensor], ) -> torch.Tensor: """ @@ -587,6 +495,8 @@ def _compute_best_scale( best_scales = None best_error = float("inf") + linears2scale = mapping.balance_layers + org_sd = { k: v.cpu() for k, v in parent_module.state_dict().items() @@ -594,8 +504,9 @@ def _compute_best_scale( } device = get_execution_device(parent_module) - x_mean = x_mean.view(-1).to(device) - w_mean = w_mean.view(-1).to(device) + + if self.duo_scaling: + x_mean, w_mean = self._compute_duo_scaling_means(mapping) for ratio in range(n_grid): # create new scales @@ -618,17 +529,8 @@ def _compute_best_scale( # Q(W * s) for linear in linears2scale: linear.weight.mul_(_scalesview) - update_offload_parameter( - linear, - "weight", - _pseudo_quantize_tensor( - w=linear.weight.data, - symmetric=self._symmetric, - bit_width=self._num_bits, - group_size=self._group_size, - )[0] - / _scalesview, - ) + call_observer(linear, "weight", linear.weight) # assert is memoryless observer + linear.weight.div_(_scalesview) # W * X int_w_outputs = self._run_samples(parent_module) @@ -696,47 +598,28 @@ def _assert_all_activations_consumed(self): raise RuntimeError("Some cached activations were not used") -def _pseudo_quantize_tensor( - w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 -): - org_w_shape = w.shape - if group_size > 0: - assert org_w_shape[-1] % group_size == 0, ( - f"org_w_shape ({org_w_shape[-1]}) must be a multiple " - + f"of group_size ({group_size})!" 
- ) - w = w.reshape(-1, group_size) - assert w.dim() == 2 - assert torch.isnan(w).sum() == 0 - - # zero point quantization - if not symmetric: - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2**bit_width - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int - zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) - w = ( - torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros - ) * scales - zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1) - else: - max_val = w.abs().amax(dim=1, keepdim=True) - max_val = max_val.clamp(min=1e-5) - max_int = 2 ** (bit_width - 1) - 1 - min_int = -(2 ** (bit_width - 1)) - scales = max_val / max_int - zeros = None - w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales - - assert torch.isnan(scales).sum() == 0 - assert torch.isnan(w).sum() == 0 - - scales = scales.view(org_w_shape[0], -1) - w = w.reshape(org_w_shape) - - return w, scales, zeros + def _compute_duo_scaling_means(self, mapping: ResolvedMapping): + balance_layers = mapping.balance_layers + + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together + weight = torch.cat([bl.weight for bl in balance_layers], dim=0) + org_shape = weight.shape + # The weights are reshaped to be organised by quantization group + weight = weight.view(-1, self._group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. + weight.abs_() + weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) + # Resizes the rescaled weight matrix back up to its original dimensions + weight = weight.view(org_shape) + # Gets the average rescaled magnitude for each output channel + w_mean = weight.mean(0) + + x_mean = self._smooth_activation_means[mapping.smooth_name][0] + + return x_mean, w_mean def _accumulate_mean( diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 5540532c97..f47d1d8fee 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -83,7 +83,7 @@ def call_observer( base_name is "weight", then the module's weight tensor will be used """ with align_module_device(module): - value = module.weight if base_name == "weight" else value + value = value or (module.weight if base_name == "weight" else value) observer: Observer = getattr(module, f"{base_name}_observer") if should_calculate_gparam: From 2ea64280020ac3e5b30114868f3921fea200ebd7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 22 Oct 2025 12:22:29 -0400 Subject: [PATCH 02/14] add todo Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/awq/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e869021441..0ad0d3025a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -21,6 +21,7 @@ ResolvedMapping, get_layer_mappings_from_architecture, ) +from llmcompressor.observers.helpers import _flatten_weight from llmcompressor.modifiers.quantization.calibration import call_observer, update_weight_zp_scale from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -601,6 +602,10 @@ def 
_assert_all_activations_consumed(self): def _compute_duo_scaling_means(self, mapping: ResolvedMapping): balance_layers = mapping.balance_layers + # TODO: validate that all layers have the same quantization_scheme.weights + # either generalize this to compute means with different strategy shapes + # or throw error if strategy is not channel/group + # [STEP 1]: Compute per-channel mean of normalised weights # All layer weights are concatted together weight = torch.cat([bl.weight for bl in balance_layers], dim=0) From 8174d06bedbcabe642233eefabadee683f5cf5bc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 22 Oct 2025 12:23:58 -0400 Subject: [PATCH 03/14] forward quantize Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/awq/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 0ad0d3025a..4fdb5c6c86 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.quantization import disable_quantization +from compressed_tensors.quantization import disable_quantization, forward_quantize from compressed_tensors.utils import ( align_modules, get_execution_device, @@ -531,6 +531,7 @@ def _compute_best_scale( for linear in linears2scale: linear.weight.mul_(_scalesview) call_observer(linear, "weight", linear.weight) # assert is memoryless observer + linear.weight = forward_quantize(linear.weight) linear.weight.div_(_scalesview) # W * X From 6d6d382b62ce9a68c16a65d745fa5acfe64560b2 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 20:41:38 +0000 Subject: [PATCH 04/14] more updates Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 60 ++++++++++++++----------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 15aea9c27b..ac20a533f4 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -11,7 +11,7 @@ update_offload_parameter, ) from loguru import logger -from pydantic import ConfigDict, PrivateAttr +from pydantic import ConfigDict, PrivateAttr, model_validator from torch.nn import Module from tqdm import tqdm @@ -22,6 +22,7 @@ ResolvedMapping, get_layer_mappings_from_architecture, ) +from llmcompressor.observers.helpers import _flatten_weight from llmcompressor.modifiers.quantization.calibration import ( call_observer, update_weight_zp_scale, @@ -399,6 +400,7 @@ def _apply_smoothing(self, model: Module) -> None: calibration_forward_context(model), HooksMixin.disable_hooks(), ): + # [STEP 3]: Compute output of module # could cache from hook, rather than recomputing here fp16_outputs = self._run_samples(parent_module) @@ -426,9 +428,7 @@ def _apply_smoothing(self, model: Module) -> None: continue # [STEP 4]: Compute loss - best_scales = self._compute_best_scale( - parent_module, mapping, fp16_outputs - ) + best_scales = self._compute_best_scale(mapping, fp16_outputs) @torch.no_grad() def _smooth(module): @@ -484,43 +484,48 @@ def _run_samples(self, module: Module) -> list[torch.Tensor]: module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] ] return [ - # If Tuple, assume that first argument is the input + # If tuple, assume that first argument is the input output[0] if isinstance(output, tuple) else output for output in outputs 
] def _compute_best_scale( self, - parent_module: torch.nn.Module, mapping: ResolvedMapping, - fp16_outputs: List[torch.Tensor], + fp16_outputs: list[torch.Tensor], ) -> torch.Tensor: """ - Compute loss and select best scales + Select best scales for a given mapping in a grid search + Best scales are those that minimize MSE loss of quantized weight + outputs compared to fp16_outputs L(s) = || Q(W * s) (s^-1 * X) - W * X || Q: weight quantization function | _pseudo_quantize_tensor(W * s) X: inputs from calib dataset | X W: original weights in FP16 | layer s: per channel scaling factor | s^-1 * X + + :param mapping: best scales will be found for thi ResolvedMapping. + :param fp16_outputs: output of mapping.parent in unquantized case, + one tensor for each batch. + :return: tensor of best scales, one for each channel """ history = [] best_ratio = -1 best_scales = None best_error = float("inf") - linears2scale = mapping.balance_layers - org_sd = { k: v.cpu() - for k, v in parent_module.state_dict().items() + for k, v in mapping.parent.state_dict().items() if v.device != torch.device("meta") } - device = get_execution_device(parent_module) + device = get_execution_device(mapping.parent) + x_mean = self._smooth_activation_means[mapping.smooth_name][0] if self.duo_scaling: - x_mean, w_mean = self._compute_duo_scaling_means(mapping) + w_mean = self._compute_layer_means(mapping.balance_layers) match self.duo_scaling: # if self.duo_scaling is "both", perform half the grid search with @@ -531,6 +536,7 @@ def _compute_best_scale( case _: n_grid = self.n_grid duo_scalings = [self.duo_scaling] + for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): # create new scales ratio = grid_idx / n_grid @@ -550,16 +556,16 @@ def _compute_best_scale( scales[torch.isnan(scales)] = 1 # Q(W * s) - for linear in linears2scale: - linear.weight.mul_(_scalesview) + for balance_layer in mapping.balance_layers: + balance_layer.weight.mul_(_scalesview) call_observer( - linear, "weight", linear.weight + balance_layer, "weight", balance_layer.weight ) # assert is memoryless observer - linear.weight = forward_quantize(linear.weight) - linear.weight.div_(_scalesview) + balance_layer.weight = forward_quantize(balance_layer.weight) + balance_layer.weight.div_(_scalesview) # W * X - int_w_outputs = self._run_samples(parent_module) + int_w_outputs = self._run_samples(mapping.parent) # compute mean squared error (L2 norm) loss = self._compute_loss(fp16_outputs, int_w_outputs, device) @@ -570,7 +576,7 @@ def _compute_best_scale( best_ratio = ratio best_scales = scales.clone() - parent_module.load_state_dict(org_sd, strict=False) + mapping.parent.load_state_dict(org_sd, strict=False) if best_ratio == -1: logger.debug(history) @@ -593,7 +599,7 @@ def _compute_loss( fp16_outputs: list[torch.Tensor], int_w_outputs: list[torch.Tensor], device: torch.device, - ) -> torch.Tensor: + ) -> float: loss = 0.0 num_elements = 0 @@ -623,19 +629,21 @@ def _assert_all_activations_consumed(self): if len(self._smooth_activation_means) != 0: raise RuntimeError("Some cached activations were not used") - def _compute_duo_scaling_means(self, mapping: ResolvedMapping): - balance_layers = mapping.balance_layers + def _compute_layer_means( + self, balance_layers: list[torch.nn.Module] + ) -> torch.Tensor: # TODO: validate that all layers have the same quantization_scheme.weights # either generalize this to compute means with different strategy shapes # or throw error if strategy is not channel/group + _group_size = 128 # [STEP 1]: Compute 
per-channel mean of normalised weights # All layer weights are concatted together weight = torch.cat([bl.weight for bl in balance_layers], dim=0) org_shape = weight.shape # The weights are reshaped to be organised by quantization group - weight = weight.view(-1, self._group_size) + weight = weight.view(-1, _group_size) # Calculates the relative magnitude of the weights within # each of the quantization groups, and rescales each group # individually so that each group has weights on a 0-1 scale. @@ -646,9 +654,7 @@ def _compute_duo_scaling_means(self, mapping: ResolvedMapping): # Gets the average rescaled magnitude for each output channel w_mean = weight.mean(0) - x_mean = self._smooth_activation_means[mapping.smooth_name][0] - - return x_mean, w_mean + return w_mean def _accumulate_mean( From 743f4dc684c2b9d1a4fd6811cb5065133866f93a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 22:50:25 +0000 Subject: [PATCH 05/14] working Signed-off-by: Brian Dellabetta --- examples/awq/llama_example.py | 4 +- src/llmcompressor/modifiers/awq/base.py | 214 +++++++++++------- .../modifiers/quantization/calibration.py | 3 +- 3 files changed, 141 insertions(+), 80 deletions(-) diff --git a/examples/awq/llama_example.py b/examples/awq/llama_example.py index d06a2ccb91..e31304b293 100644 --- a/examples/awq/llama_example.py +++ b/examples/awq/llama_example.py @@ -50,7 +50,9 @@ def tokenize(sample): # Configure the quantization algorithm to run. recipe = [ - AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"]), + AWQModifier( + ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"], duo_scaling="both" + ), ] # Apply algorithms. diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ac20a533f4..ced433a912 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -3,13 +3,20 @@ from typing import Literal import torch -from compressed_tensors.quantization import disable_quantization, forward_quantize +from compressed_tensors.quantization import ( + disable_quantization, + forward_quantize, + QuantizationStrategy, +) from compressed_tensors.utils import ( align_modules, get_execution_device, match_named_modules, update_offload_parameter, + patch_attrs, ) +from llmcompressor.observers.base import Observer + from loguru import logger from pydantic import ConfigDict, PrivateAttr, model_validator from torch.nn import Module @@ -22,7 +29,6 @@ ResolvedMapping, get_layer_mappings_from_architecture, ) -from llmcompressor.observers.helpers import _flatten_weight from llmcompressor.modifiers.quantization.calibration import ( call_observer, update_weight_zp_scale, @@ -329,7 +335,7 @@ def _setup_activation_cache_hooks(self) -> None: """ def cache_parent_kwargs_hook( - module: torch.nn.Module, + module: Module, args: tuple[torch.Tensor, ...], kwargs, ): @@ -338,7 +344,7 @@ def cache_parent_kwargs_hook( def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( - _module: torch.nn.Module, + _module: Module, args: tuple[torch.Tensor, ...], _output: torch.Tensor, ): @@ -401,8 +407,7 @@ def _apply_smoothing(self, model: Module) -> None: HooksMixin.disable_hooks(), ): - # [STEP 3]: Compute output of module - # could cache from hook, rather than recomputing here + # Compute output of unquantized module fp16_outputs = self._run_samples(parent_module) if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): logger.info( @@ -427,11 +432,10 @@ def 
_apply_smoothing(self, model: Module) -> None: del self._smooth_activation_means[mapping.smooth_name] continue - # [STEP 4]: Compute loss best_scales = self._compute_best_scale(mapping, fp16_outputs) @torch.no_grad() - def _smooth(module): + def _smooth(module: Module): scales = best_scales.to(module.weight.device) if module in balance_layers: update_offload_parameter( @@ -512,6 +516,7 @@ def _compute_best_scale( """ history = [] best_ratio = -1 + best_duo_scaling = -1 best_scales = None best_error = float("inf") @@ -523,9 +528,9 @@ def _compute_best_scale( device = get_execution_device(mapping.parent) - x_mean = self._smooth_activation_means[mapping.smooth_name][0] + x_mean = self._smooth_activation_means[mapping.smooth_name][0].to(device) if self.duo_scaling: - w_mean = self._compute_layer_means(mapping.balance_layers) + w_mean = self._compute_layer_means(mapping.balance_layers).to(device) match self.duo_scaling: # if self.duo_scaling is "both", perform half the grid search with @@ -537,46 +542,75 @@ def _compute_best_scale( n_grid = self.n_grid duo_scalings = [self.duo_scaling] - for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): - # create new scales - ratio = grid_idx / n_grid - - # NOTE: s^-1 * x is fused here, according to paper - if use_duo_scaling: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( - min=1e-4 + # Replace observers with memoryless_minmax for duration of grid search + with patch_attrs( + mapping.balance_layers, + "weight_observer", + [ + Observer.load_from_registry( + "memoryless_minmax", + base_name="weight", + args=balance_layer.quantization_scheme.weights, + module=balance_layer, ) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - _scalesview = scales.view(1, -1).to(device) - - # avoid scaling values that overflow - scales[torch.isinf(scales)] = 1 - scales[torch.isnan(scales)] = 1 - - # Q(W * s) - for balance_layer in mapping.balance_layers: - balance_layer.weight.mul_(_scalesview) - call_observer( - balance_layer, "weight", balance_layer.weight - ) # assert is memoryless observer - balance_layer.weight = forward_quantize(balance_layer.weight) - balance_layer.weight.div_(_scalesview) + for balance_layer in mapping.balance_layers + if hasattr(balance_layer, "quantization_scheme") + and hasattr(balance_layer.quantization_scheme, "weights") + ], + ): + for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): + # create new scales + ratio = grid_idx / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if use_duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + _scalesview = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for balance_layer in mapping.balance_layers: + if not hasattr(balance_layer, "quantization_scheme") or not hasattr( + balance_layer.quantization_scheme, "weights" + ): + continue + + balance_layer.weight.mul_(_scalesview) + call_observer(balance_layer, "weight", balance_layer.weight) + update_offload_parameter( + balance_layer, + "weight", + forward_quantize( + balance_layer, + balance_layer.weight.data, + "weight", + balance_layer.quantization_scheme.weights, + ) + / _scalesview, + ) - # W * X - int_w_outputs = self._run_samples(mapping.parent) 
+ # W * X + int_w_outputs = self._run_samples(mapping.parent) - # compute mean squared error (L2 norm) - loss = self._compute_loss(fp16_outputs, int_w_outputs, device) + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_outputs, int_w_outputs, device) - history.append(loss) - if loss < best_error: - best_error = loss - best_ratio = ratio - best_scales = scales.clone() + history.append(loss) + if loss < best_error: + best_error = loss + best_duo_scaling = use_duo_scaling + best_ratio = ratio + best_scales = scales.clone() - mapping.parent.load_state_dict(org_sd, strict=False) + mapping.parent.load_state_dict(org_sd, strict=False) if best_ratio == -1: logger.debug(history) @@ -605,14 +639,10 @@ def _compute_loss( # Compute the MSE loss for each batch for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): - batch_loss = ( - (fp16_batch.to(device) - int_w_batch.to(device)) - .view(-1) - .float() - .pow(2) - .sum() - .item() - ) + batch_loss = torch.nn.functional.mse_loss( + fp16_batch.to(device), int_w_batch.to(device) + ).item() + loss += batch_loss num_elements += fp16_batch.numel() @@ -629,32 +659,60 @@ def _assert_all_activations_consumed(self): if len(self._smooth_activation_means) != 0: raise RuntimeError("Some cached activations were not used") - def _compute_layer_means( - self, balance_layers: list[torch.nn.Module] - ) -> torch.Tensor: + def _compute_layer_means(self, layers: list[Module]) -> torch.Tensor: + """ + Compute per-channel mean of normalised weights for all passed in layers + Each layer is processed separately rather than copying all weights + into a single tensor, + """ + group_size = None + + # to calculate mean without having to carry full population + weight_total_count = 0 + weight_total_sum = None + + for layer in layers: + if not hasattr(layer, "weight"): + continue + + weight = layer.weight + org_shape = weight.shape + + group_size = _infer_group_size(layer) + + # The weights are reshaped to be organised by quantization group + if group_size > 0: + weight = weight.view(-1, group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. + weight.abs_() + weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) + # Resizes the rescaled weight matrix back up to its original dimensions + weight = weight.view(org_shape) + + # Gets the average rescaled magnitude for each output channel + weight_total_count += weight.size(0) + weight_sum = weight.sum(0, dtype=torch.float64) + if weight_total_sum is None: + weight_total_sum = weight_sum + else: + weight_total_sum += weight_sum + + return weight_total_sum / weight_total_count + - # TODO: validate that all layers have the same quantization_scheme.weights - # either generalize this to compute means with different strategy shapes - # or throw error if strategy is not channel/group - _group_size = 128 - - # [STEP 1]: Compute per-channel mean of normalised weights - # All layer weights are concatted together - weight = torch.cat([bl.weight for bl in balance_layers], dim=0) - org_shape = weight.shape - # The weights are reshaped to be organised by quantization group - weight = weight.view(-1, _group_size) - # Calculates the relative magnitude of the weights within - # each of the quantization groups, and rescales each group - # individually so that each group has weights on a 0-1 scale. 
- weight.abs_() - weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - # Resizes the rescaled weight matrix back up to its original dimensions - weight = weight.view(org_shape) - # Gets the average rescaled magnitude for each output channel - w_mean = weight.mean(0) - - return w_mean +def _infer_group_size(layer: Module) -> int: + """ + Returns group_size of layer if applicable, otherwise -1 + """ + if ( + hasattr(layer, "quantization_scheme") + and hasattr(layer.quantization_scheme, "weights") + and layer.quantization_scheme.weights.strategy == QuantizationStrategy.GROUP + ): + return layer.quantization_scheme.weights.group_size + return -1 def _accumulate_mean( diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 16b31e939a..d71a2b2190 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -78,7 +78,8 @@ def call_observer( base_name is "weight", then the module's weight tensor will be used """ with align_module_device(module): - value = value or (module.weight if base_name == "weight" else value) + if value is None and base_name == "weight": + value = module.weight observer: Observer = getattr(module, f"{base_name}_observer") if should_calculate_gparam: From c73806a8cfe9630aec86fb3a7dfeae62058b7ada Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 23:07:05 +0000 Subject: [PATCH 06/14] docstrings Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ced433a912..89f4b34ca6 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -625,6 +625,8 @@ def _compute_best_scale( torch.isnan(best_scales).sum() == 0 ), f"Nan found in scales: {best_scales}" + print("BEST CONFIGURATION", best_duo_scaling, best_ratio) + return best_scales.detach().cpu() @torch.no_grad() @@ -661,9 +663,12 @@ def _assert_all_activations_consumed(self): def _compute_layer_means(self, layers: list[Module]) -> torch.Tensor: """ - Compute per-channel mean of normalised weights for all passed in layers - Each layer is processed separately rather than copying all weights - into a single tensor, + Compute per-channel mean of normalised weights for all passed in layers. + Layers with group-wise quantization will be normalized against the group + abs max instead of the abs max of the channel. + + To minimize memory requirements, layers are reduced to a running total + of sums and counts when calculating mean """ group_size = None @@ -678,17 +683,15 @@ def _compute_layer_means(self, layers: list[Module]) -> torch.Tensor: weight = layer.weight org_shape = weight.shape - group_size = _infer_group_size(layer) - - # The weights are reshaped to be organised by quantization group - if group_size > 0: + # If group-wise, calculate abs max based on group + # abs max, rather than channel + if (group_size := _infer_group_size(layer)) > 0: weight = weight.view(-1, group_size) - # Calculates the relative magnitude of the weights within - # each of the quantization groups, and rescales each group - # individually so that each group has weights on a 0-1 scale. 
+ weight.abs_() weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - # Resizes the rescaled weight matrix back up to its original dimensions + + # Reshape back to original dimensions weight = weight.view(org_shape) # Gets the average rescaled magnitude for each output channel From 71a961e8e8fb2003d66456cc52c530569ff95960 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 23:13:59 +0000 Subject: [PATCH 07/14] touchup Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 89f4b34ca6..a711f16aab 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -625,8 +625,6 @@ def _compute_best_scale( torch.isnan(best_scales).sum() == 0 ), f"Nan found in scales: {best_scales}" - print("BEST CONFIGURATION", best_duo_scaling, best_ratio) - return best_scales.detach().cpu() @torch.no_grad() From 842ff1cddb2a9637cad3833a81a4d2d3c1555185 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 13 Nov 2025 23:31:05 +0000 Subject: [PATCH 08/14] formatting Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index a711f16aab..ded74a1da7 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -4,21 +4,19 @@ import torch from compressed_tensors.quantization import ( + QuantizationStrategy, disable_quantization, forward_quantize, - QuantizationStrategy, ) from compressed_tensors.utils import ( align_modules, get_execution_device, match_named_modules, - update_offload_parameter, patch_attrs, + update_offload_parameter, ) -from llmcompressor.observers.base import Observer - from loguru import logger -from pydantic import ConfigDict, PrivateAttr, model_validator +from pydantic import ConfigDict, PrivateAttr from torch.nn import Module from tqdm import tqdm @@ -35,6 +33,7 @@ ) from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.observers.base import Observer from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context @@ -406,7 +405,6 @@ def _apply_smoothing(self, model: Module) -> None: calibration_forward_context(model), HooksMixin.disable_hooks(), ): - # Compute output of unquantized module fp16_outputs = self._run_samples(parent_module) if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): From 0ee2382ac3bc3ed4fd838f26be3d6fc04ff9515b Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 18:42:06 +0000 Subject: [PATCH 09/14] unit test for compute layer means Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 5 +- .../llmcompressor/modifiers/awq/test_base.py | 73 ++++++++++++++++++- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ded74a1da7..51d1f5c3e1 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -623,6 +623,8 @@ def _compute_best_scale( torch.isnan(best_scales).sum() == 0 ), f"Nan found in scales: {best_scales}" + 
print("BEST CONFIGURATION", best_duo_scaling, best_ratio) + return best_scales.detach().cpu() @torch.no_grad() @@ -657,7 +659,8 @@ def _assert_all_activations_consumed(self): if len(self._smooth_activation_means) != 0: raise RuntimeError("Some cached activations were not used") - def _compute_layer_means(self, layers: list[Module]) -> torch.Tensor: + @staticmethod + def _compute_layer_means(layers: list[Module]) -> torch.Tensor: """ Compute per-channel mean of normalised weights for all passed in layers. Layers with group-wise quantization will be normalized against the group diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 950ab0f51a..d9492a6ed2 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -1,7 +1,12 @@ import pytest import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, +) from pydantic import ValidationError +from torch.testing import assert_close from llmcompressor.modifiers.awq import AWQMapping, AWQModifier from llmcompressor.modifiers.awq.base import get_lowest_common_parent @@ -234,3 +239,69 @@ def test_get_lowest_common_parent(): ["embed_tokens", "decoder.self_attn.v_proj"], model ) assert parent_name == "" and parent == model + + +@torch.no_grad +@pytest.mark.unit +@pytest.mark.parametrize( + "n_balance_layers, group_size, n_input_features", + [ + (5, None, 32), + (4, 10, 40), + ], +) +def test_awq_compute_layer_means(n_balance_layers, group_size, n_input_features): + """ + Confirm our logic to compute duo_scaling layer means via a running tally + matches the original memory-intensive AutoAWQ implementation, which concats + all balance layers into a single tensor before reducing to mean + Large models were prone to fail at this step. + """ + balance_layers = [ + torch.nn.Linear(n_input_features, 10) for _ in range(n_balance_layers) + ] + for balance_layer in balance_layers: + setattr( + balance_layer, + "quantization_scheme", + QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + strategy=( + QuantizationStrategy.GROUP + if group_size is not None + else QuantizationStrategy.CHANNEL + ), + group_size=group_size, + ), + ), + ) + + ##### + ##### Original AutoAwq implementation + ##### + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together + weight = torch.cat([bl.weight for bl in balance_layers], dim=0) + org_shape = weight.shape + # The weights are reshaped to be organised by quantization group + if group_size is not None: + weight = weight.view(-1, group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. 
+ weight.abs_() + weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) + weight = weight.view(org_shape) + # Gets the average rescaled magnitude for each output channel + w_mean_auto_awq = weight.mean(0) + del weight + ##### + ##### Original AutoAwq implementation + ##### + + w_mean_awq = AWQModifier._compute_layer_means(balance_layers).to( + w_mean_auto_awq.dtype + ) + + assert_close(w_mean_auto_awq, w_mean_awq) From e54daf8a634ea335cda63a05ebc11f116eef585d Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 18:49:11 +0000 Subject: [PATCH 10/14] improve validation logic in compute best scale Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 51d1f5c3e1..033f9a6ad4 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -540,9 +540,16 @@ def _compute_best_scale( n_grid = self.n_grid duo_scalings = [self.duo_scaling] - # Replace observers with memoryless_minmax for duration of grid search + # Where appropriate, replace observers with memoryless_minmax + # for duration of grid search + balance_layers_to_patch = [ + balance_layer + for balance_layer in mapping.balance_layers + if hasattr(balance_layer, "quantization_scheme") + and hasattr(balance_layer.quantization_scheme, "weights") + ] with patch_attrs( - mapping.balance_layers, + balance_layers_to_patch, "weight_observer", [ Observer.load_from_registry( @@ -551,9 +558,7 @@ def _compute_best_scale( args=balance_layer.quantization_scheme.weights, module=balance_layer, ) - for balance_layer in mapping.balance_layers - if hasattr(balance_layer, "quantization_scheme") - and hasattr(balance_layer.quantization_scheme, "weights") + for balance_layer in balance_layers_to_patch ], ): for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): @@ -575,7 +580,7 @@ def _compute_best_scale( scales[torch.isnan(scales)] = 1 # Q(W * s) - for balance_layer in mapping.balance_layers: + for balance_layer in balance_layers_to_patch: if not hasattr(balance_layer, "quantization_scheme") or not hasattr( balance_layer.quantization_scheme, "weights" ): From 554442cd224be59297ef5ae3ccdc8d90218605a2 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 18:53:33 +0000 Subject: [PATCH 11/14] add block-wise TODO Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 033f9a6ad4..57c29d592a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -674,6 +674,8 @@ def _compute_layer_means(layers: list[Module]) -> torch.Tensor: To minimize memory requirements, layers are reduced to a running total of sums and counts when calculating mean """ + # TODO: allow for block-wise layer means as well + group_size = None # to calculate mean without having to carry full population From 1c6696a7dcd95e9072f4fa679c4e6757c6c28664 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 20:57:57 +0000 Subject: [PATCH 12/14] minor cleanup Signed-off-by: Brian Dellabetta --- .../llmcompressor/modifiers/awq/test_base.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py 
b/tests/llmcompressor/modifiers/awq/test_base.py index d9492a6ed2..65c7a50dd3 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -250,7 +250,7 @@ def test_get_lowest_common_parent(): (4, 10, 40), ], ) -def test_awq_compute_layer_means(n_balance_layers, group_size, n_input_features): +def test_compute_layer_means(n_balance_layers, group_size, n_input_features): """ Confirm our logic to compute duo_scaling layer means via a running tally matches the original memory-intensive AutoAWQ implementation, which concats @@ -277,28 +277,27 @@ def test_awq_compute_layer_means(n_balance_layers, group_size, n_input_features) ), ) - ##### - ##### Original AutoAwq implementation - ##### - # [STEP 1]: Compute per-channel mean of normalised weights - # All layer weights are concatted together - weight = torch.cat([bl.weight for bl in balance_layers], dim=0) - org_shape = weight.shape - # The weights are reshaped to be organised by quantization group - if group_size is not None: - weight = weight.view(-1, group_size) - # Calculates the relative magnitude of the weights within - # each of the quantization groups, and rescales each group - # individually so that each group has weights on a 0-1 scale. - weight.abs_() - weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - weight = weight.view(org_shape) - # Gets the average rescaled magnitude for each output channel - w_mean_auto_awq = weight.mean(0) - del weight - ##### - ##### Original AutoAwq implementation - ##### + def _auto_awq_compute_layer_means(layers: list[torch.nn.Module]) -> torch.Tensor: + """ + Original AutoAwq implementation + """ + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together + weight = torch.cat([bl.weight for bl in balance_layers], dim=0) + org_shape = weight.shape + # The weights are reshaped to be organised by quantization group + if group_size is not None: + weight = weight.view(-1, group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. 
+ weight.abs_() + weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) + weight = weight.view(org_shape) + # Gets the average rescaled magnitude for each output channel + return weight.mean(0) + + w_mean_auto_awq = _auto_awq_compute_layer_means(balance_layers) w_mean_awq = AWQModifier._compute_layer_means(balance_layers).to( w_mean_auto_awq.dtype From 65c3b2d527e403cf12cb092287404f7c4a55c4d8 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 21:00:24 +0000 Subject: [PATCH 13/14] remove validation tests Signed-off-by: Brian Dellabetta --- .../llmcompressor/modifiers/awq/test_base.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 65c7a50dd3..289fe24019 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -119,28 +119,26 @@ def test_set_resolved_mappings(): @pytest.mark.unit def test_validate(): - with pytest.raises(ValidationError): - AWQModifier(scheme="W8A8") + AWQModifier(scheme="W8A8") - with pytest.raises(ValidationError): - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=64, - ), + AWQModifier( + config_groups={ + "group_0": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + group_size=64, ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=128, - ), + ), + "group_1": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + group_size=128, ), - } - ) + ), + } + ) with pytest.raises(ValidationError): AWQModifier( From bdcdca416210aa3c4632c7d41c423c853cfdd2d5 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 14 Nov 2025 21:28:51 +0000 Subject: [PATCH 14/14] remove validation tests Signed-off-by: Brian Dellabetta --- .../llmcompressor/modifiers/awq/test_base.py | 55 ------------------- 1 file changed, 55 deletions(-) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 289fe24019..32bf9a490d 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -119,61 +119,6 @@ def test_set_resolved_mappings(): @pytest.mark.unit def test_validate(): - AWQModifier(scheme="W8A8") - - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=64, - ), - ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=128, - ), - ), - } - ) - - with pytest.raises(ValidationError): - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - group_size=128, - ), - ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=8, - group_size=128, - ), - ), - } - ) - - # valid configuration - AWQModifier( - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False), - ), - "group_1": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False), - ), - } - ) - AWQModifier(scheme="W4A16", duo_scaling="both") with pytest.raises(ValidationError): AWQModifier(scheme="W4A16", 
duo_scaling="Both")