diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index bc86ba25f6..98e53b4e00 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,6 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union
+from itertools import product
+from typing import Literal
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -111,31 +112,40 @@ class AWQModifier(Modifier, QuantizationMixin):
         device. Defaults to None, so cached args are not offloaded. Consider
         setting to torch.device("cpu") if you are encountering OOM errors
     :param duo_scaling: whether to use duo scaling, which uses both input activations
-        and weights to determine the scaling factor
+        and weights to determine the scaling factor. Defaults to True
+        If True, both activations and weights are used.
+        If False, only activations are used.
+        If "both", half the grid search is performed with duo_scaling=False and the
+        other half is performed with duo_scaling=True.
+    :param n_grid: when performing the best scales grid search for each mapping,
+        this specifies how many grid points should be used. To decrease the runtime,
+        at the possible cost of slightly worse scales, this can be decreased.
+        Defaults to 20
     """
 
     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
 
     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
-    mappings: Optional[List[AWQMapping]] = None
-    offload_device: Optional[torch.device] = None
-    duo_scaling: bool = True
+    sequential_targets: str | list[str] | None = None
+    mappings: list[AWQMapping] | None = None
+    offload_device: torch.device | None = None
+    duo_scaling: bool | Literal["both"] = True
+    n_grid: int = 20
 
     # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
+    _num_bits: int | None = PrivateAttr(default=None)
+    _symmetric: bool | None = PrivateAttr(default=None)
+    _group_size: int | None = PrivateAttr(default=None)
 
     # Private vars set during initialization, cleared during finalization
-    _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
+    _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
 
     # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
-    _smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
+    _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
     )
@@ -389,7 +399,7 @@ def _setup_activation_cache_hooks(self) -> None:
 
         def cache_parent_kwargs_hook(
             module: torch.nn.Module,
-            args: Tuple[torch.Tensor, ...],
+            args: tuple[torch.Tensor, ...],
             kwargs,
         ):
             values = inspect.signature(module.forward).bind(*args, **kwargs)
@@ -398,7 +408,7 @@ def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
+                args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
                 self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -454,9 +464,11 @@ def _apply_smoothing(self, model: Module) -> None:
             balance_layers = mapping.balance_layers
             parent_module = mapping.parent
 
-            with align_modules(
-                [parent_module, smooth_layer, *balance_layers]
-            ), calibration_forward_context(model), HooksMixin.disable_hooks():
+            with (
+                align_modules([parent_module, smooth_layer, *balance_layers]),
+                calibration_forward_context(model),
+                HooksMixin.disable_hooks(),
+            ):
                 # [STEP 1]: Compute per-channel mean of normalised weights
                 # All layer weights are concatted together
                 weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
@@ -559,13 +571,13 @@ def _smooth(module):
                 v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
         return [
             # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
+            output[0] if isinstance(output, tuple) else output
             for output in outputs
         ]
 
@@ -574,8 +586,8 @@ def _compute_best_scale(
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        linears2scale: list[torch.nn.Linear],
+        fp16_outputs: list[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -586,7 +598,6 @@
         W: original weights in FP16 | layer
         s: per channel scaling factor | s^-1 * X
         """
-        n_grid = 20
         history = []
         best_ratio = -1
         best_scales = None
@@ -602,12 +613,21 @@
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        match self.duo_scaling:
+            # if self.duo_scaling is "both", perform half the grid search with
+            # duo_scaling off and half with duo_scaling on
+            case "both":
+                n_grid = int(self.n_grid / 2)
+                duo_scalings = [False, True]
+            case _:
+                n_grid = self.n_grid
+                duo_scalings = [self.duo_scaling]
+        for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
-            if self.duo_scaling:
+            if use_duo_scaling:
                 scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
                     min=1e-4
                 )
@@ -667,8 +687,8 @@
     @torch.no_grad()
     def _compute_loss(
         self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
+        fp16_outputs: list[torch.Tensor],
+        int_w_outputs: list[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
@@ -746,8 +766,8 @@ def _pseudo_quantize_tensor(
 
 def _accumulate_mean(
     inp: torch.Tensor,
-    prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
-) -> Tuple[torch.FloatTensor, int]:
+    prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
+) -> tuple[torch.FloatTensor, int]:
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
@@ -761,7 +781,7 @@
     return (prev_sum + sum_added) / new_count, new_count
 
 
-def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
     """
     Given a list of names, returns the lowest-scope common parent.
 
diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py
index 907bca4880..9bec035df4 100644
--- a/src/llmcompressor/modifiers/awq/mappings.py
+++ b/src/llmcompressor/modifiers/awq/mappings.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional
 
 from loguru import logger
 from torch.nn import Module
@@ -143,7 +142,7 @@ class AWQMapping:
     #     ["re:.*dense$"]
     # ),
 ]
-AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
+AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
     "BloomForCausalLM": _bloom_mappings,
     "CohereForCausalLM": _cohere_mappings,
     "Cohere2ForCausalLM": _cohere_mappings,
@@ -186,13 +185,13 @@ class ResolvedMapping:
 
     smooth_name: str
     smooth_layer: Module
-    balance_layers: List[Module]
-    balance_names: Optional[List[str]] = None
-    parent: Optional[Module] = None
-    parent_name: Optional[str] = None
+    balance_layers: list[Module]
+    balance_names: list[str]
+    parent: Module
+    parent_name: str
 
 
-def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
+def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
     """
     :param architecture: str: The architecture of the model
     :return: list: The layer mappings for the given architecture
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index a66a278f32..950ab0f51a 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -171,6 +171,12 @@ def test_validate():
             }
         )
 
+    AWQModifier(scheme="W4A16", duo_scaling="both")
+    with pytest.raises(ValidationError):
+        AWQModifier(scheme="W4A16", duo_scaling="Both")
+    with pytest.raises(ValidationError):
+        AWQModifier(scheme="W4A16", duo_scaling="x")
+
 
 @pytest.mark.unit
 def test_get_lowest_common_parent():
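The `match`/`product` refactor in `_compute_best_scale` is easier to follow in isolation. Below is a standalone sketch of the enumeration logic; the helper name `enumerate_grid` is illustrative and not part of the diff:

```python
from itertools import product


def enumerate_grid(duo_scaling, n_grid: int = 20):
    """Mirror of the candidate enumeration added to _compute_best_scale."""
    if duo_scaling == "both":
        # half the grid points, each evaluated with and without duo scaling
        n_grid = int(n_grid / 2)
        duo_scalings = [False, True]
    else:
        duo_scalings = [duo_scaling]
    return [
        (grid_idx / n_grid, use_duo_scaling)
        for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings)
    ]


# duo_scaling=True keeps the original behavior: 20 ratios from 0.00 to 0.95
assert len(enumerate_grid(True)) == 20
# duo_scaling="both" tries 10 ratios, each with and without duo scaling,
# so the total number of candidate scales evaluated stays at 20
assert len(enumerate_grid("both")) == 20
```

The key design point is that `duo_scaling="both"` does not double the runtime: `n_grid` is halved so the loss is still evaluated over the same number of candidate scales.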
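For trying out the two new options end to end, here is a minimal usage sketch. The model id, calibration dataset, and oneshot arguments are illustrative placeholders and not part of this change:

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# duo_scaling="both" splits the grid search between duo_scaling=False and
# duo_scaling=True; n_grid=20 keeps the total number of grid points unchanged.
recipe = AWQModifier(
    targets=["Linear"],
    scheme="W4A16",
    ignore=["lm_head"],
    duo_scaling="both",
    n_grid=20,
)

oneshot(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    dataset="open_platypus",  # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=64,
)
```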