80 changes: 50 additions & 30 deletions src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,6 @@
import inspect
from typing import Dict, List, Optional, Tuple, Union
from itertools import product
from typing import Literal

import torch
from compressed_tensors.quantization import disable_quantization
@@ -111,31 +112,40 @@ class AWQModifier(Modifier, QuantizationMixin):
device. Defaults to None, so cached args are not offloaded. Consider setting
to torch.device("cpu") if you are encountering OOM errors
:param duo_scaling: whether to use duo scaling, which uses both input activations
and weights to determine the scaling factor
and weights to determine the scaling factor. Defaults to True.
If True, both activations and weights are used.
If False, only activations are used.
If "both", half the grid search is performed with duo_scaling=False and the
other half with duo_scaling=True.
:param n_grid: when performing the best-scales grid search for each mapping,
this specifies how many grid points should be used. Decreasing this value
reduces runtime, at the possible cost of slightly worse scales.
Defaults to 20
"""

# Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

# User-provided vars (in addition to QuantizationMixin args)
sequential_targets: Union[str, List[str], None] = None
mappings: Optional[List[AWQMapping]] = None
offload_device: Optional[torch.device] = None
duo_scaling: bool = True
sequential_targets: str | list[str] | None = None
mappings: list[AWQMapping] | None = None
offload_device: torch.device | None = None
duo_scaling: bool | Literal["both"] = True
n_grid: int = 20

# Private vars set during validation
_num_bits: Optional[int] = PrivateAttr(default=None)
_symmetric: Optional[bool] = PrivateAttr(default=None)
_group_size: Optional[int] = PrivateAttr(default=None)
_num_bits: int | None = PrivateAttr(default=None)
_symmetric: bool | None = PrivateAttr(default=None)
_group_size: int | None = PrivateAttr(default=None)

# Private vars set during initialization, cleared during finalization
_resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
_resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
# Cache list of forward input args for each parent module, one dict for each batch
_parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
_parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
default_factory=dict
)
# Dict[smooth layer name, (activation means, activation counts)]
_smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
_smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
default_factory=dict
)
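For context on the two new user-facing fields, a hedged usage sketch (the scheme value is taken from the tests below; the n_grid value and import path are illustrative):

```python
from llmcompressor.modifiers.awq import AWQModifier

# Illustrative configuration: "both" spends half the grid search with
# duo scaling off and half with it on; a smaller n_grid trades scale
# quality for runtime.
modifier = AWQModifier(
    scheme="W4A16",
    duo_scaling="both",  # bool | Literal["both"]
    n_grid=10,           # default is 20
)
```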

@@ -389,7 +399,7 @@ def _setup_activation_cache_hooks(self) -> None:

def cache_parent_kwargs_hook(
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
args: tuple[torch.Tensor, ...],
kwargs,
):
values = inspect.signature(module.forward).bind(*args, **kwargs)
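The hook relies on signature binding to cache every forward argument under its parameter name; a standalone sketch of that pattern (toy signature, not the real model forward):

```python
import inspect

def forward(x, attention_mask=None, position_ids=None):
    return x

# Bind positional and keyword args against the signature so each value
# can be cached under its parameter name, as the hook above does.
bound = inspect.signature(forward).bind("hidden_states", position_ids=[0])
bound.apply_defaults()
print(bound.arguments)
# {'x': 'hidden_states', 'attention_mask': None, 'position_ids': [0]}
```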
@@ -398,7 +408,7 @@ def create_cache_smooth_activations_hook_fn(smooth_name):
def create_cache_smooth_activations_hook_fn(smooth_name):
def cache_smooth_activations_hook(
_module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
args: tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -454,9 +464,11 @@ def _apply_smoothing(self, model: Module) -> None:
balance_layers = mapping.balance_layers
parent_module = mapping.parent

with align_modules(
[parent_module, smooth_layer, *balance_layers]
), calibration_forward_context(model), HooksMixin.disable_hooks():
with (
align_modules([parent_module, smooth_layer, *balance_layers]),
calibration_forward_context(model),
HooksMixin.disable_hooks(),
):
# [STEP 1]: Compute per-channel mean of normalised weights
# All layer weights are concatted together
weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
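For intuition, a simplified sketch of the [STEP 1] normalisation in the AWQ style; the grouping assumes a group size of 16 and may differ from the exact code, which is truncated in this view:

```python
import torch

# Concatenate balance-layer weights, normalise each quantization group
# by its max |w|, then average over rows: one mean per input channel.
weight = torch.cat([torch.randn(8, 16), torch.randn(8, 16)], dim=0)
org_shape = weight.shape
grouped = weight.view(-1, 16)  # assumed group size of 16
w_scale = grouped.abs() / (grouped.abs().amax(dim=1, keepdim=True) + 1e-6)
w_mean = w_scale.view(org_shape).mean(dim=0)  # shape: [16]
```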
@@ -559,13 +571,13 @@ def _smooth(module):
v.batch_intermediates.clear()
self._assert_all_activations_consumed()

def _run_samples(self, module: Module) -> List[torch.Tensor]:
def _run_samples(self, module: Module) -> list[torch.Tensor]:
outputs = [
module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
]
return [
# If tuple, assume that the first element is the relevant output
output[0] if isinstance(output, Tuple) else output
output[0] if isinstance(output, tuple) else output
for output in outputs
]

@@ -574,8 +586,8 @@ def _compute_best_scale(
x_mean: torch.Tensor,
w_mean: torch.Tensor,
parent_module: torch.nn.Module,
linears2scale: List[torch.nn.Linear],
fp16_outputs: List[torch.Tensor],
linears2scale: list[torch.nn.Linear],
fp16_outputs: list[torch.Tensor],
) -> torch.Tensor:
"""
Compute loss and select best scales
@@ -586,7 +598,6 @@ def _compute_best_scale(
W: original weights in FP16 | layer
s: per channel scaling factor | s^-1 * X
"""
n_grid = 20
history = []
best_ratio = -1
best_scales = None
@@ -602,12 +613,21 @@ def _compute_best_scale(
x_mean = x_mean.view(-1).to(device)
w_mean = w_mean.view(-1).to(device)

for ratio in range(n_grid):
match self.duo_scaling:
# if self.duo_scaling is "both", perform half the grid search with
# duo_scaling off and half with duo_scaling on
case "both":
n_grid = int(self.n_grid / 2)
duo_scalings = [False, True]
case _:
n_grid = self.n_grid
duo_scalings = [self.duo_scaling]
for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
# create new scales
ratio = ratio / n_grid
ratio = grid_idx / n_grid

# NOTE: s^-1 * x is fused here, according to paper
if self.duo_scaling:
if use_duo_scaling:
scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
min=1e-4
)
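To make the new control flow concrete, a minimal sketch showing that "both" keeps the total search budget unchanged while covering both modes:

```python
from itertools import product

n_grid_cfg, duo_scaling = 20, "both"  # illustrative values

if duo_scaling == "both":
    n_grid, duo_scalings = n_grid_cfg // 2, [False, True]
else:
    n_grid, duo_scalings = n_grid_cfg, [duo_scaling]

candidates = list(product(range(n_grid), duo_scalings))
assert len(candidates) == 20  # 10 ratios x 2 modes: same budget as before
for grid_idx, use_duo_scaling in candidates:
    ratio = grid_idx / n_grid  # 0.0, 0.1, ..., 0.9, each visited twice
```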
@@ -667,8 +687,8 @@ def _compute_best_scale(
@torch.no_grad()
def _compute_loss(
self,
fp16_outputs: List[torch.Tensor],
int_w_outputs: List[torch.Tensor],
fp16_outputs: list[torch.Tensor],
int_w_outputs: list[torch.Tensor],
device: torch.device,
) -> torch.Tensor:
loss = 0.0
@@ -746,8 +766,8 @@ def _pseudo_quantize_tensor(

def _accumulate_mean(
inp: torch.Tensor,
prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
) -> Tuple[torch.FloatTensor, int]:
prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
) -> tuple[torch.FloatTensor, int]:
sum_added = inp.sum(dim=0)
num_added = inp.size(0)
if prev_mean_and_count is None:
@@ -761,7 +781,7 @@ def _accumulate_mean(
return (prev_sum + sum_added) / new_count, new_count
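A quick numeric check of the running-mean recurrence implemented above (standalone re-implementation for illustration):

```python
import torch

def accumulate_mean(inp, prev_mean_and_count):
    # fold one batch into a running (mean, count) pair
    sum_added, num_added = inp.sum(dim=0), inp.size(0)
    if prev_mean_and_count is None:
        return sum_added / num_added, num_added
    prev_mean, prev_count = prev_mean_and_count
    new_count = prev_count + num_added
    return (prev_mean * prev_count + sum_added) / new_count, new_count

a, b = torch.tensor([[1.0], [3.0]]), torch.tensor([[5.0]])
mean, count = accumulate_mean(b, accumulate_mean(a, None))
assert count == 3 and torch.allclose(mean, torch.tensor([3.0]))  # (1+3+5)/3
```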


def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
"""
Given a list of names, returns the lowest-scope common parent.
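For intuition, a hypothetical call (module names are illustrative, not from the source):

```python
# get_lowest_common_parent(
#     ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"],
#     model,
# )
# -> ("model.layers.0.mlp", <the mlp Module>)
```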

13 changes: 6 additions & 7 deletions src/llmcompressor/modifiers/awq/mappings.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from loguru import logger
from torch.nn import Module
@@ -143,7 +142,7 @@ class AWQMapping:
# ["re:.*dense$"]
# ),
]
AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
"BloomForCausalLM": _bloom_mappings,
"CohereForCausalLM": _cohere_mappings,
"Cohere2ForCausalLM": _cohere_mappings,
@@ -186,13 +185,13 @@ class ResolvedMapping:

smooth_name: str
smooth_layer: Module
balance_layers: List[Module]
balance_names: Optional[List[str]] = None
parent: Optional[Module] = None
parent_name: Optional[str] = None
balance_layers: list[Module]
balance_names: list[str]
parent: Module
parent_name: str


def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
"""
:param architecture: str: The architecture of the model
:return: list: The layer mappings for the given architecture
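A hedged usage sketch; the fallback behaviour for unregistered architectures is an assumption inferred from the logger import above, since the function body is truncated in this view:

```python
from llmcompressor.modifiers.awq.mappings import (
    get_layer_mappings_from_architecture,
)

# Registered architectures resolve to their curated mapping list;
# unknown architectures are assumed to fall back to a default mapping
# with a logged notice.
mappings = get_layer_mappings_from_architecture("CohereForCausalLM")
```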
6 changes: 6 additions & 0 deletions tests/llmcompressor/modifiers/awq/test_base.py
@@ -171,6 +171,12 @@ def test_validate():
}
)

AWQModifier(scheme="W4A16", duo_scaling="both")
with pytest.raises(ValidationError):
AWQModifier(scheme="W4A16", duo_scaling="Both")
with pytest.raises(ValidationError):
AWQModifier(scheme="W4A16", duo_scaling="x")


@pytest.mark.unit
def test_get_lowest_common_parent():