From 79c7e86b2b2171bd652a3ed18c0da9ab7d675998 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 7 Oct 2025 17:41:44 -0400
Subject: [PATCH 01/23] refactor observers

Signed-off-by: Kyle Sayers
---
 docs/observers.md                             |   6 +-
 .../modifiers/quantization/cache.py           |  10 +-
 .../modifiers/quantization/calibration.py     |  75 +---
 .../quantization/gptq/gptq_quantize.py        |   6 +-
 src/llmcompressor/observers/base.py           | 320 +++----------
 src/llmcompressor/observers/helpers.py        | 128 ++++++-
 src/llmcompressor/observers/min_max.py        | 163 ++-------
 src/llmcompressor/observers/mse.py            | 215 ++++-------
 .../modifiers/calibration/test_cache.py       |   2 +-
 .../modifiers/calibration/test_lifecycle.py   | 337 ++++++++++++++++++
 .../modifiers/calibration/test_observers.py   |  23 +-
 tests/llmcompressor/observers/test_helpers.py | 129 +++----
 tests/llmcompressor/observers/test_min_max.py |  29 +-
 tests/llmcompressor/observers/test_mse.py     |  39 +-
 14 files changed, 733 insertions(+), 749 deletions(-)
 create mode 100644 tests/llmcompressor/modifiers/calibration/test_lifecycle.py

diff --git a/docs/observers.md b/docs/observers.md
index 342c7dec9a..c5dd978a59 100644
--- a/docs/observers.md
+++ b/docs/observers.md
@@ -65,7 +65,11 @@ from llmcompressor.observers import Observer
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 
 args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
-observer = Observer.load_from_registry("minmax", quantization_args=args)
+observer = Observer.load_from_registry(
+    "minmax",
+    base_name="weight",
+    args=args,
+)
 
 x = torch.randn(64, 512)
 scale, zero_point = observer(x)
diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index b09b418127..53eca8d075 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -86,13 +86,15 @@ def update(
         """
 
         if len(self.k_observers) <= layer_idx:
-            k_observer_name = self.quantization_args.observer
             k_observer = Observer.load_from_registry(
-                k_observer_name, quantization_args=self.quantization_args
+                self.quantization_args.observer,
+                base_name="k",
+                args=self.quantization_args,
             )
-            v_observer_name = self.quantization_args.observer
             v_observer = Observer.load_from_registry(
-                v_observer_name, quantization_args=self.quantization_args
+                self.quantization_args.observer,
+                base_name="v",
+                args=self.quantization_args,
             )
 
             # NOTE: User may ignore some layers in configuration,
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 96b400d63e..5540532c97 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -5,6 +5,7 @@
 from compressed_tensors.quantization import (
     DynamicType,
     KVCacheScaleType,
+    QuantizationArgs,
     QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
@@ -19,12 +20,6 @@
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
-DEFAULT_MAXSHRINK = 0.20
-DEFAULT_PATIENCE = 5
-DEFAULT_AVERAGING_CONSTANT = 0.01
-DEFAULT_GRID = 100.0
-DEFAULT_NORM = 2.4
-
 __all__ = [
     "initialize_observer",
     "update_weight_zp_scale",
@@ -54,31 +49,19 @@ def initialize_observer(
 
     :param base_name: str used to name the observer attribute
     """
-
-    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
- # no quantization scheme nothing to do - return - - quantization_args = getattr(quantization_scheme, arg_name, None) - # dont need observers for dynamic - if quantization_args is not None and quantization_args.dynamic in ( - False, - DynamicType.LOCAL, - ): - observer_kwargs = quantization_args.observer_kwargs or {} + if base_name == "weight": + arg_name = "weights" + elif base_name == "output": + arg_name = "output_activations" + else: # input, q, k, v + arg_name = "input_activations" + + args: QuantizationArgs = getattr_chain( + module, f"quantization_scheme.{arg_name}", None + ) + if args is not None and args.dynamic is not True: observer = Observer.load_from_registry( - quantization_args.observer, - quantization_args=quantization_args, - averaging_constant=observer_kwargs.get( - "averaging_constant", DEFAULT_AVERAGING_CONSTANT - ), - # used by mse observer only, will be ignored by minmax observer - maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK), - patience=observer_kwargs.get("patience", DEFAULT_PATIENCE), - grid=observer_kwargs.get("grid", DEFAULT_GRID), - norm=observer_kwargs.get("norm", DEFAULT_NORM), + args.observer, base_name=base_name, args=args, module=module ) module.register_module(f"{base_name}_observer", observer) @@ -100,36 +83,17 @@ def call_observer( base_name is "weight", then the module's weight tensor will be used """ with align_module_device(module): - if base_name == "weight": - value = module.weight - g_idx = getattr(module, "weight_g_idx", None) - elif value is not None: - g_idx = None - else: - raise ValueError( - "Must provide a value to observe if not using weight observer" - ) - - observer = getattr(module, f"{base_name}_observer") + value = module.weight if base_name == "weight" else value + observer: Observer = getattr(module, f"{base_name}_observer") if should_calculate_gparam: - global_scale = observer( - value, - should_calculate_gparam=True, - ) + global_scale = observer.get_global_scale(value) update_offload_parameter(module, f"{base_name}_global_scale", global_scale) - else: - global_scale = getattr(module, f"{base_name}_global_scale", None) if should_calculate_qparams: - updated_scale, updated_zero_point = observer( - value, g_idx=g_idx, global_scale=global_scale - ) - # register or update scale & zero_point parameters (supports block shapes) - scale_name = f"{base_name}_scale" - zp_name = f"{base_name}_zero_point" - update_offload_parameter(module, scale_name, updated_scale) - update_offload_parameter(module, zp_name, updated_zero_point) + scale, zero_point = observer(value) + update_offload_parameter(module, f"{base_name}_scale", scale) + update_offload_parameter(module, f"{base_name}_zero_point", zero_point) def update_weight_global_scale(module: Module): @@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module): should_calculate_gparam=True, should_calculate_qparams=False, ) - module.weight_observer.reset() def update_weight_zp_scale(module: Module): diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 4392ed8cfd..28926650fe 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -95,8 +95,10 @@ def quantize_weight( # create observer for calculating quantization parameters observer = Observer.load_from_registry( - quant_args.observer, - quantization_args=quant_args, + "minmax", + base_name="weight", + args=quant_args, + module=module, 
averaging_constant=1.0, # ignore moving average ) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 6ca6e203c3..4af5c37e39 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -1,17 +1,16 @@ -from math import ceil -from typing import Any, Iterable, Optional, Tuple, Union +from abc import abstractmethod +from typing import Optional, Tuple +from weakref import ref import torch from compressed_tensors import InternalModule from compressed_tensors.quantization.quant_args import ( - FP8_E4M3_DATA, QuantizationArgs, - QuantizationStrategy, ) -from compressed_tensors.quantization.utils import is_fp4 +from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin -from loguru import logger -from torch import FloatTensor, IntTensor, Tensor + +from llmcompressor.observers.helpers import flatten_for_calibration __all__ = ["Observer"] @@ -25,287 +24,70 @@ class Observer(InternalModule, RegistryMixin): def __init__( self, - quantization_args: QuantizationArgs, + base_name: str, + args: QuantizationArgs, + module: Optional[torch.nn.Module] = None, + **observer_kwargs, ): - self.quantization_args: QuantizationArgs = quantization_args super().__init__() - self._scale = None - self._zero_point = None - self._num_observed_tokens = None - - @torch.no_grad() - def forward( - self, - observed: Tensor, - g_idx: Optional[Tensor] = None, - global_scale: Optional[Tensor] = None, - should_calculate_gparam: bool = False, - ) -> Tuple[FloatTensor, IntTensor]: - """ - maps directly to get_qparams - :param observed: optional observed tensor from which to calculate - quantization parameters - :param g_idx: optional mapping from column index to group index - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point based on last observed value - """ - self.record_observed_tokens(observed) - if should_calculate_gparam: - return self.get_gparam(observed=observed) - return self.get_qparams( - observed=observed, - g_idx=g_idx, - global_scale=global_scale, - ) + self.module = ref(module) if module is not None else None + self.base_name = base_name + self.args = args - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: optional id for tracking separate statistics when different - ranges of observed tensors are passed, useful for sharding tensors by - group_size or block quantization - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") + # populate observer kwargs + self.args.observer_kwargs = self.args.observer_kwargs or {} + self.args.observer_kwargs.update(observer_kwargs) - def calculate_gparam( - self, - observed: Tensor, - ) -> torch.Tensor: - """ - :param observed: observed tensor to calculate quantization parameters for - :return: global scale derived from the observed tensor - """ - raise 
NotImplementedError(f"{self.__class__} must implement calculate_gparam") + # used for moving averages and testing + self.min_vals = None + self.max_vals = None - def post_calculate_qparams(self) -> None: - """ - Run any logic specific to its observers after running calculate_qparams + @abstractmethod + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ + Calculates updated scales and zero points from observed value + (weight, activation, or attention state). - def get_gparam(self, observed: Tensor): - """ - Function to derive a global scale parameter - :param observed: observed tensor to calculate global parameters - from - :return: derived global scale + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ - if self.quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: - return self.calculate_gparam(observed) - raise NotImplementedError( - "global parameter generation is only supported for TENSOR_GROUP" - ) + raise NotImplementedError() - def get_qparams( - self, - observed: Optional[Tensor] = None, - g_idx: Optional[Tensor] = None, - global_scale: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: + def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Convenience function to wrap overwritten calculate_qparams - adds support to make observed tensor optional and support for tracking latest - calculated scale and zero point + Calculates updated scales and zero points from observed value + (weight, activation, or attention state). - :param observed: optional observed tensor to calculate quantization parameters - from - :param g_idx: optional mapping from column index to group index - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point based on last observed value + :param observed: value being observed + :return: calibrated scale and zero point """ - if observed is not None: - group_size = self.quantization_args.group_size - - if self.quantization_args.strategy == QuantizationStrategy.TENSOR: - # re-calculate scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams(observed) - - elif self.quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - rows = observed.shape[0] - columns = observed.shape[1] - num_groups = int(ceil(columns / group_size)) - if num_groups * group_size != columns: - logger.bind(log_once=True).warning( - "Attempting to quantize a module weight whose columns " - f"({columns}) are not divisible by group_size ({group_size}). 
" - "This scheme is not supported by vLLM, please consider " - "adjusting the group_size for modules with this number of " - "columns", - ) - - self._scale = torch.empty( - (rows, num_groups), dtype=observed.dtype, device=observed.device - ) - if is_fp4(quantization_args=self.quantization_args): - zp_dtype = FP8_E4M3_DATA.dtype - else: - zp_dtype = self.quantization_args.pytorch_dtype() + g_idx = self._get_module_param("g_idx") + global_scale = self._get_module_param("global_scale") - self._zero_point = torch.empty( - (rows, num_groups), dtype=zp_dtype, device=observed.device - ) + observed = flatten_for_calibration(observed, self.base_name, self.args, g_idx) + self.min_vals, self.max_vals = self.get_min_max(observed) - # support column-order (default) quantization as well as other orderings - # such as activation ordering. Below checks if g_idx has initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - - perm = torch.argsort(g_idx) - observed = observed.index_select(dim=1, index=perm) - - # TODO: experiment with vectorizing for loop for performance - end = 0 - for group_index, group_count in enumerate(group_sizes): - start = end - end = start + group_count - scale, zero_point = self.get_qparams_along_dim( - observed[:, start:end], - 0, - tensor_id=group_index, - global_scale=global_scale, - ) - - self._scale[:, group_index] = scale.squeeze(1) - self._zero_point[:, group_index] = zero_point.squeeze(1) - - elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: - # assume observed is transposed, because its the output, hence use dim 0 - self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0) - - elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: - # use dim 1, assume the obsersed.shape = [batch, token, hidden] - # should be batch, token - self._scale, self._zero_point = self.get_qparams_along_dim( - observed, - dim={0, 1}, - ) - - elif self.quantization_args.strategy == QuantizationStrategy.BLOCK: - # Block-wise quantization: one scale/zero_point per block of shape - # [block_rows, block_cols] - rows, cols = observed.shape[:2] - bs = self.quantization_args.block_structure - if not ( - isinstance(bs, (list, tuple)) - and len(bs) == 2 - and all(isinstance(x, int) for x in bs) - ): - raise ValueError( - f"Invalid block_structure '{bs}'. " - f"Must be a list of two ints [rows, cols]." 
- ) - block_rows, block_cols = bs - num_br = int(ceil(rows / block_rows)) - num_bc = int(ceil(cols / block_cols)) - - # allocate per-block scale and zero_point - self._scale = torch.empty( - (num_br, num_bc), dtype=observed.dtype, device=observed.device - ) - - # Use same dtype logic as GROUP strategy for zero_point - if is_fp4(quantization_args=self.quantization_args): - zp_dtype = FP8_E4M3_DATA.dtype - else: - zp_dtype = self.quantization_args.pytorch_dtype() - - self._zero_point = torch.empty( - (num_br, num_bc), dtype=zp_dtype, device=observed.device - ) - - # compute qparams for each block - for i in range(num_br): - r0 = i * block_rows - r1 = min((i + 1) * block_rows, rows) - for j in range(num_bc): - c0 = j * block_cols - c1 = min((j + 1) * block_cols, cols) - # reduce across both dims to get one scale and zp per block - # Use unique tensor_id for each block to maintain separate stats - block_tensor_id = f"block_{i}_{j}" - scale_bp, zp_bp = self.calculate_qparams( - observed[r0:r1, c0:c1], - reduce_dims=(0, 1), - tensor_id=block_tensor_id, - ) - self._scale[i, j] = scale_bp - self._zero_point[i, j] = zp_bp - - return self._scale, self._zero_point - - def get_qparams_along_dim( - self, - observed, - dim: Union[int, Iterable[int]], - tensor_id: Optional[Any] = None, - global_scale: Optional[Tensor] = None, - ): - if isinstance(dim, int): - dim = [dim] - dim = set(dim) - - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, + return calculate_qparams( + min_vals=self.min_vals, + max_vals=self.max_vals, + quantization_args=self.args, global_scale=global_scale, ) - def record_observed_tokens(self, batch_tensor: Tensor): + def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter: """ - Counts the number of tokens observed during the - forward passes. The count is aggregated in the - _num_observed_tokens attribute of the class. + Calculates updated global scale from observed value - Note: The batch_tensor is expected to have two dimensions - (batch_size * sequence_length, num_features). This is the - general shape expected by the forward pass of the expert - layers in a MOE model. If the input tensor does not have - two dimensions, the _num_observed_tokens attribute will be set - to None. + :param observed: value being observed + :return: calibrated global parameter """ - if not isinstance(batch_tensor, Tensor): - raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}") + observed = observed.reshape((1, 1, -1)) # per tensor reshape + min_vals, max_vals = self.get_min_max(observed) + return generate_gparam(min_vals, max_vals) - if batch_tensor.ndim != 2: - logger.debug( - "The input tensor is expected to have two dimensions " - "(batch_size * sequence_length, num_features). " - f"The input tensor has {batch_tensor.ndim} dimensions." 
- ) - return + def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: + if self.module is None: + return None - if self._num_observed_tokens is None: - # initialize the count - self._num_observed_tokens = 0 - - # batch_tensor (batch_size * sequence_length, num_features) - # observed_tokens (batch_size * sequence_length) - observed_tokens, _ = batch_tensor.shape - self._num_observed_tokens += observed_tokens - - def reset(self): - """ - Reset the state of the observer - """ - self._num_observed_tokens = None - self._scale = None - self._zero_point = None + return getattr(self.module(), f"{self.base_name}_{name}", None) diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index 5cd32ff645..4560da1b85 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -7,25 +7,125 @@ pruning operations. """ -from collections import Counter +from typing import Optional import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.utils import strategy_cdiv -__all__ = ["get_observer_token_count"] +__all__ = ["flatten_for_calibration"] -def get_observer_token_count(module: torch.nn.Module) -> Counter: +def flatten_for_calibration( + value: torch.Tensor, + base_name: str, + args: QuantizationArgs, + g_idx: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ - Parse the module and return the number of tokens observed by - each module's observer. + Reshapes the value according to the quantization strategy for the purposes of + scale/zp calibration. The value after flattening has the following shape: - :param module: module to parse - :return: counter with the number of tokens observed by each observer + `(num_observations, *qparam_shape, group_size)` + + The first dim is the number of observations (usually the batch size times number of + tokens), the middle dims are the dimension of the scales, and the last dim is the + number of elements being quantized per group. + + :param value: value being flattened + :param base_name: weight, input, output, q/k/v. 
Used to characterize the value as + being a weight, activation, or attention state + :param args: quantization args for determining how the value is flattened + :param g_idx: optional gidx for weight activation ordering + :return: value which has been reshaped for calibration """ - token_counts = Counter() - for name, module in module.named_modules(): - if name.endswith(".input_observer"): - token_counts[name.replace(".input_observer", "")] = ( - module._num_observed_tokens - ) - return token_counts + if base_name == "weight": + return _flatten_weight(value, args, g_idx) + elif base_name in ("input", "output"): + return _flatten_activation(value, args) + elif base_name in ("q", "k", "v"): + return _flatten_attention(value, args) + else: + raise ValueError(f"Unknown quantization base name: {base_name}") + + +def _flatten_weight( + value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None +): + if args.strategy == QuantizationStrategy.TENSOR: + # (1, 1, num_weight_elems) + return value.reshape((1, 1, -1)) + + if args.strategy == QuantizationStrategy.TOKEN: + raise ValueError("Token quantization cannot be applied to weights") + + if args.strategy == QuantizationStrategy.CHANNEL: + # (1, num_rows, 1, num_cols) + return value.unsqueeze(-2).unsqueeze(0) + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + if g_idx is not None: + value = value.index_select(dim=1, index=torch.argsort(g_idx)) + + # (1, num_rows, num_groups, group_size) + return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0) + + if args.strategy == QuantizationStrategy.BLOCK: + # (1, num_block_rows, num_block_cols, block_width * block_height) + block_height, block_width = args.block_structure + rows, cols = value.shape + block_rows = strategy_cdiv(rows, block_height, args.strategy, strict=True) + block_cols = strategy_cdiv(cols, block_width, args.strategy, strict=True) + return ( + value.reshape(block_rows, block_height, block_cols, block_width) + .transpose(1, 2) + .flatten(-2, -1) + .unsqueeze(0) + ) + + assert False, f"Unknown strategy {args.strategy}" + + +def _flatten_activation(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (batch_size * seq_len, 1, hidden_dim) + return value.reshape((-1, 1, value.size(-1))) + + if args.strategy == QuantizationStrategy.TOKEN: + # (batch_size, seq_len, hidden_dim) + # warning: token quantization uses `compute_dynamic_scales_and_zp` + return value + + if args.strategy == QuantizationStrategy.CHANNEL: + raise ValueError("Channel quantization cannot be applied to activations") + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + # (batch_size * seq_len, num_groups, group_size) + # warning: group activation quantization uses compute_dynamic_scales_and_zp + return value.flatten(0, 1).unflatten(-1, (-1, args.group_size)) + + if args.strategy == QuantizationStrategy.BLOCK: + raise ValueError("Block quantization cannot be applied to activations") + + assert False, f"Unknown strategy {args.strategy}" + + +def _flatten_attention(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (batch_size, seq_len, num_heads, head_dim) + # (batch_size * seq_len, 1, num_heads * head_dim) + return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) + + if args.strategy == QuantizationStrategy.TOKEN: + raise ValueError("Token quantization cannot be applied to attention") + + if args.strategy == 
QuantizationStrategy.CHANNEL: + raise ValueError("Channel quantization cannot be applied to attention") + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + raise ValueError("Group quantization cannot be applied to attention") + + if args.strategy == QuantizationStrategy.BLOCK: + raise ValueError("Block quantization cannot be applied to attention") + + assert False, f"Unknown strategy {args.strategy}" diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index ce5c0e7790..5dbe8f31e4 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -1,13 +1,11 @@ -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import torch from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam -from compressed_tensors.utils import deprecated from llmcompressor.observers.base import Observer -__all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"] +__all__ = ["MinMaxObserver"] @Observer.register("minmax") @@ -20,142 +18,39 @@ class MinMaxObserver(Observer): def __init__( self, - quantization_args: QuantizationArgs, - averaging_constant: float = 0.01, - **kwargs, + base_name: str, + args: QuantizationArgs, + module: Optional[torch.nn.Module] = None, + **observer_kwargs, ): - super().__init__(quantization_args=quantization_args) + super().__init__(base_name, args, module, **observer_kwargs) - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant + observer_kwargs = self.args.observer_kwargs + self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01) - def calculate_updated_min_max( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ): + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant. Set the averaging_constant to 1.0 to disable averaging. + Calculates updated scales and zero points from observed value using the absolute + min and max value. If `averaging_constant` is specified, then subsequent calls + will affect a moving average by the specified constant. 
- :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: updated min and max values + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ - tensor_id = tensor_id or "default" - - if not reduce_dims: - min_val, max_val = torch.aminmax(observed) - else: - min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - # early stopping, save some computation and memory - if self.averaging_constant == 1.0: - return min_val, max_val - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - return updated_min_val, updated_max_val + min_vals = torch.amin(observed, dim=(0, -1)) + max_vals = torch.amax(observed, dim=(0, -1)) - def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor: - """ - Generate a global scale using the observed min and max. - - :param observed: observed tensor to calculate quantization parameters for - :return: updated global scale derived from the observed tensor - """ - - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed - ) - return generate_gparam( - updated_min_val=updated_min_val, updated_max_val=updated_max_val - ) - - def calculate_qparams( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, torch.IntTensor]: - """ - Generate a scale and zero-point using the observed min and max. 
- - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed, tensor_id=tensor_id, reduce_dims=reduce_dims - ) - return calculate_qparams( - min_vals=updated_min_val, - max_vals=updated_max_val, - quantization_args=self.quantization_args, - global_scale=global_scale, - ) - - def get_qparams_along_dim( - self, - observed: torch.Tensor, - dim: int, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ): - """ - Calculate quantization parameters along the specified dimension - """ - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, - global_scale=global_scale, - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} + if self.min_vals is not None and self.averaging_constant != 1.0: + # FUTURE: consider scaling by num observations (first dim) + # rather than reducing by first dim + min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant) + max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant) + return min_vals, max_vals -class MovingAverageMinMaxObserver(MinMaxObserver): - @deprecated( - message=( - "The class name `MovingAverageMinMaxObserver` has been deprecated, please " - "initialize with `MinMaxObserver` in the future" - ) - ) - def __new__(cls, *args, **kwargs): - return super().__new__(MinMaxObserver, *args, **kwargs) + def _lerp( + self, input: torch.Tensor, end: torch.Tensor, weight: float + ) -> torch.Tensor: + """torch lerp_kernel is not implemeneted for all data types""" + return (input * (1.0 - weight)) + (end * weight) diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index 419155f077..c33c08d6d2 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -1,9 +1,12 @@ -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import ( + QuantizationArgs, + QuantizationStrategy, +) from compressed_tensors.quantization.utils import calculate_qparams -from torch import FloatTensor, IntTensor, Tensor +from compressed_tensors.utils import patch_attr from llmcompressor.observers.base import Observer @@ -19,53 +22,58 @@ class MovingAverageMSEObserver(Observer): def __init__( self, - quantization_args: QuantizationArgs, - maxshrink: float = 0.2, - patience: int = 5, - averaging_constant: float = 0.01, - grid: float = 100.0, - norm: float = 2.4, - **kwargs, + base_name: str, + args: QuantizationArgs, + module: Optional[torch.nn.Module] = None, + **observer_kwargs, ): - super().__init__(quantization_args=quantization_args) + super().__init__(base_name, args, module, **observer_kwargs) - self.min_val = {} - self.max_val = {} - self.maxshrink = maxshrink - 
self.patience = patience - self.averaging_constant = averaging_constant - self.grid = grid - self.norm = norm + observer_kwargs = self.args.observer_kwargs + self.maxshrink = observer_kwargs.get("maxshrink", 0.20) + self.patience = observer_kwargs.get("patience", 5) + self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01) + self.grid = observer_kwargs.get("grid", 100.0) + self.norm = observer_kwargs.get("norm", 2.4) - def calculate_mse_min_max( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - global_scale: Optional[torch.Tensor] = None, - ): + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Computes the mse-clipped min and max values of the observed tensor by - optimizing for quantization error - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned values will be shaped (1,) along the reduced dimensions - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of min and max values derived from the observed tensor + Calculates updated scales and zero points from observed value. Minimum and + maximum values are chosen by grid searching across min/max values which minimize + quantization reconstruction loss. + + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ - from compressed_tensors.quantization.lifecycle import fake_quantize + min_vals, max_vals = self._mse_min_max(observed) + + if self.min_vals is not None and self.averaging_constant != 1.0: + # FUTURE: consider scaling by num observations (first dim) + # rather than reducing by first dim + min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant) + max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant) - if not reduce_dims: - absolute_min_val, absolute_max_val = torch.aminmax(observed) - else: - absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) + return min_vals, max_vals + + def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Grid search for MSE-optimal min and max values + + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum and maximum values which minimize reconstruction error + """ + from compressed_tensors.quantization.lifecycle import fake_quantize + absolute_min_val = torch.amin(observed, dim=(0, -1)) + absolute_max_val = torch.amax(observed, dim=(0, -1)) best = torch.full_like( absolute_min_val, torch.finfo(absolute_min_val.dtype).max ) min_val = torch.ones_like(absolute_min_val) max_val = torch.zeros_like(absolute_max_val) + global_scale = self._get_module_param("global_scale") # Early stopping params no_improve_count = 0 @@ -78,24 +86,25 @@ def calculate_mse_min_max( candidate_scales, candidate_zero_points = calculate_qparams( min_vals=shrinked_min_val, max_vals=shrinked_max_val, - quantization_args=self.quantization_args, - global_scale=global_scale, - ) - q = fake_quantize( - observed, - candidate_scales, - candidate_zero_points, - self.quantization_args, + quantization_args=self.args, global_scale=global_scale, ) + # Note that observed.shape = (num_observations, *qparams_shape, group_size). 
+ # For the purposes of fake quantization, this is equivalent to token quant + with patch_attr(self.args, "strategy", QuantizationStrategy.TOKEN): + q = fake_quantize( + observed, + candidate_scales.unsqueeze(-1), + candidate_zero_points.unsqueeze(-1), + self.args, + global_scale=global_scale, + ) + q -= observed q.abs_() q.pow_(self.norm) - if not reduce_dims: - err = torch.sum(q) - else: - err = torch.sum(q, reduce_dims, keepdims=True) + err = torch.sum(q, dim=(0, -1)) tmp = err < best if torch.any(tmp): @@ -110,104 +119,8 @@ def calculate_mse_min_max( return min_val, max_val - def calculate_updated_min_max( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the mse-clipped min and max values of the observed tensor using - a moving average smoothed by the averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: updated min and max values derived from the observed value - """ - # TODO: will need to be expanded to support fp4 activations; - # currently not supported - min_val, max_val = self.calculate_mse_min_max( - observed, reduce_dims, global_scale=global_scale - ) - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - tensor_id = tensor_id or "default" - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - return updated_min_val, updated_max_val - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the mse-clipped min and max values of the observed tensor using - a moving average smoothed by the averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed, - tensor_id=tensor_id, - reduce_dims=reduce_dims, - global_scale=global_scale, - ) - scale, zero_point = calculate_qparams( - min_vals=updated_min_val, - max_vals=updated_max_val, - quantization_args=self.quantization_args, - global_scale=global_scale, - ) - return scale, zero_point - - def get_qparams_along_dim( - self, - observed, - dim: 
int, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, - global_scale=global_scale, - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} + def _lerp( + self, input: torch.Tensor, end: torch.Tensor, weight: float + ) -> torch.Tensor: + """torch lerp_kernel is not implemeneted for all data types""" + return (input * (1.0 - weight)) + (end * weight) diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py index 9b03234cf7..70f0e61259 100644 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_cache.py @@ -29,7 +29,7 @@ def test_is_quantized_cache_singleton(): args = QuantizationArgs() cache = QuantizedKVParameterCache(args) observer = args.observer - observer = Observer.load_from_registry(observer, quantization_args=args) + observer = Observer.load_from_registry(observer, base_name="k", args=args) tensor = torch.tensor([1, 2, 3]) cache.k_scales.append(tensor) diff --git a/tests/llmcompressor/modifiers/calibration/test_lifecycle.py b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py new file mode 100644 index 0000000000..dae4054636 --- /dev/null +++ b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py @@ -0,0 +1,337 @@ +import pytest +import torch +from compressed_tensors.quantization import ( + QuantizationScheme, + forward_quantize, + initialize_module_for_quantization, + initialize_qparams, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus + +from llmcompressor.modifiers.quantization.calibration import initialize_observer + + +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", # equivalent to token + ), + torch.tensor([0.0]), + torch.tensor([23.0]), + torch.tensor( + [ + [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250], + [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500], + [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.85, + ), + # token is not supported + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="channel", + ), + torch.tensor([[0], [6], [12], [18]]), + torch.tensor([[5], [11], [17], [23]]), + torch.tensor( + [ + [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875], + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], + [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.45, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="group", + group_size=3, + ), + torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]), + torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]), + torch.tensor( + [ + [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875], + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], + [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750], + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], + ], + ), + 0.45, + ), + ( + 
QuantizationArgs( + num_bits=4, + type="float", # tensor group requires FP4 + symmetric=True, + strategy="tensor_group", # requires float4 + group_size=3, + ), + torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]), + torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]), + torch.tensor( + [ + [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375], + [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875], + [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750], + [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000], + ], + ), + 1.1, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="block", + block_structure=[2, 3], + ), + torch.tensor([[0, 3], [12, 15]]), + torch.tensor([[8, 11], [20, 23]]), + torch.tensor( + [ + [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062], + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], + [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750], + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], + ], + ), + 0.5, + ), + ], +) +def test_static_weight_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + weight = tensor([[ 0, 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17], + [18, 19, 20, 21, 22, 23]]) + """ + # set up weight + input_size, output_size = 6, 4 + linear = torch.nn.Linear(input_size, output_size, bias=False) + linear.weight.data = torch.arange( + input_size * output_size, dtype=torch.bfloat16 + ).reshape(output_size, input_size) + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], weights=args) + initialize_module_for_quantization(linear, scheme) + assert getattr(linear, "quantization_scheme") is scheme + initialize_observer(linear, "weight") + + # calibrate_global_scale + if hasattr(linear, "weight_global_scale"): + global_scale = linear.weight_observer.get_global_scale(linear.weight) + linear.weight_global_scale.data = global_scale + + # calibrate quantization parameters + scale, zero_point = linear.weight_observer(linear.weight) + linear.weight_scale.data = scale + linear.weight_zero_point.data = zero_point + assert torch.equal(linear.weight_observer.min_vals, exp_min_val) + assert torch.equal(linear.weight_observer.max_vals, exp_max_val) + + # forward pass + input = torch.eye(input_size, dtype=torch.bfloat16) + output = linear(input) + + assert torch.allclose(output.T, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss + + +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", + ), + torch.tensor([0.0]), + torch.tensor([11.0]), + torch.tensor( + [ + [ + [0.0000, 1.4688, 1.4688, 2.9375, 4.4062, 4.4062], + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], + ] + ] + ), + 0.2, + ), + # static token is not supported + # channel is not supported + # group is not supported + ( + QuantizationArgs( + num_bits=4, + type="float", # must be fp4 + symmetric=True, + strategy="tensor_group", + dynamic="local", + group_size=3, + ), + None, + None, + torch.tensor( + [ + [ + [0.0000, 0.9844, 1.9688, 3.4062, 3.4062, 5.1250], + [5.2500, 7.8750, 7.8750, 7.3438, 11.0000, 11.0000], + ] + ] + ), + 0.5, + ), + # block is not supported + # head is not supported + ], +) +def test_static_activation_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + input = tensor([[ 0, 1, 2, 3, 4, 5] + [ 6, 7, 8, 9, 10, 11]]) + """ + # set up activation (and identity weight) + 
batch_size, seq_len, input_size = 1, 2, 6 + input = torch.arange( + (batch_size * seq_len * input_size), dtype=torch.bfloat16 + ).reshape((batch_size, seq_len, input_size)) + linear = torch.nn.Linear(input_size, input_size, bias=False) + linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16) + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], input_activations=args) + initialize_module_for_quantization(linear, scheme) + assert getattr(linear, "quantization_scheme") is scheme + initialize_observer(linear, "input") + + # calibrate quantization parameters + def calibrate_input_hook(_, args): + if hasattr(linear, "input_global_scale"): + global_scale = linear.input_observer.get_global_scale(args[0]) + linear.input_global_scale.data = global_scale + + if linear.quantization_scheme.input_activations.dynamic is False: + scale, zero_point = linear.input_observer(args[0]) + linear.input_scale.data = scale + linear.input_zero_point.data = zero_point + + linear.register_forward_pre_hook(calibrate_input_hook) + + # calibration forward pass + output = linear(input) + + # check calibration + if exp_min_val is not None: + assert torch.equal(linear.input_observer.min_vals, exp_min_val) + if exp_max_val is not None: + assert torch.equal(linear.input_observer.max_vals, exp_max_val) + + # check forward pass + assert torch.allclose(output, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output, input) <= exp_loss + + +class MockAttention(torch.nn.Module): + pass + + +@pytest.mark.filterwarnings("ignore::UserWarning") # cpu offloading for MockAttention +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", + ), + torch.tensor([0.0]), + torch.tensor([11.0]), + torch.tensor( + [ + [ + [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]], + [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]], + ] + ] + ), + 0.19, + ), + # static token is not supported + # channel is not supported + # group is not supported + # tensor group is not supported + # block is not supported + ], +) +def test_static_attention_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + input = tensor([[[[ 0., 1., 2.], + [ 3., 4., 5.]], + [[ 6., 7., 8.], + [ 9., 10., 11.]]]]) + """ + # set up activation (and identity weight) + batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3 + input = torch.arange( + (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16 + ).reshape((batch_size, seq_len, num_heads, head_dim)) + attention = MockAttention() + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], input_activations=args) + initialize_qparams( + attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16 + ) + attention.quantization_scheme = scheme + attention.quantization_status = QuantizationStatus.INITIALIZED + initialize_observer(attention, "k") + + # calibrate quantization parameters + if scheme.input_activations.dynamic is False: + scale, zero_point = attention.k_observer(input) + attention.k_scale.data = scale + attention.k_zero_point.data = zero_point + + # calibration forward pass + output = forward_quantize(attention, input, "k", scheme.input_activations) + + # check calibration + if exp_min_val is not None: + assert torch.equal(attention.k_observer.min_vals, exp_min_val) + if exp_max_val is not None: + assert torch.equal(attention.k_observer.max_vals, exp_max_val) + + # check 
forward pass + assert torch.allclose(output, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output, input) <= exp_loss diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index a742a48b21..57f4de40bf 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -13,17 +13,17 @@ "shape,group_size,actorder", [ ((1, 1), None, False), - ((1, 1), 128, False), - ((1, 1), 128, True), + ((1, 1), 1, False), + ((1, 1), 1, True), ((64, 64), None, False), - ((64, 64), 128, False), - ((64, 64), 128, True), - ((1792, 4096), None, False), - ((1792, 4096), 128, False), - ((1792, 4096), 128, True), - ((3420, 64), None, False), - ((3420, 64), 128, False), - ((3420, 64), 128, True), + ((64, 64), 32, False), + ((64, 64), 32, True), + ((896, 4096), None, False), + ((896, 4096), 7, False), + ((896, 4096), 7, True), + ((512, 64), None, False), + ((512, 64), 128, False), + ((512, 64), 128, True), ], ) def test_observers_update(shape, group_size, actorder): @@ -49,8 +49,7 @@ def test_observers_update(shape, group_size, actorder): ("output", output), ): observer = getattr(module, f"{location}_observer") - g_idx = getattr(module, "g_idx", None) - updated_scale, updated_zero_point = observer(value, g_idx=g_idx) + updated_scale, updated_zero_point = observer(value) assert_alike(updated_scale, getattr(module, f"{location}_scale")) assert_alike(updated_zero_point, getattr(module, f"{location}_zero_point")) diff --git a/tests/llmcompressor/observers/test_helpers.py b/tests/llmcompressor/observers/test_helpers.py index 527176019b..5b1909828c 100644 --- a/tests/llmcompressor/observers/test_helpers.py +++ b/tests/llmcompressor/observers/test_helpers.py @@ -12,98 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest import torch from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, + QuantizationArgs, + QuantizationScheme, + initialize_module_for_quantization, ) -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor.modifiers.quantization.calibration import ( - calibrate_input_hook, - initialize_observer, -) -from llmcompressor.observers.helpers import get_observer_token_count - - -def _prep_for_input_quant_calibration(module: torch.nn.Module): - quantization_scheme = getattr(module, "quantization_scheme", None) - if not quantization_scheme: - return - - module.register_forward_pre_hook(calibrate_input_hook) - module.quantization_status = QuantizationStatus.CALIBRATION +from llmcompressor.observers.helpers import flatten_for_calibration -def test_get_observer_token_count(): - model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - model.eval() - config = QuantizationConfig( - format="fakequant", - quantization_status="calibration", - config_groups={ - "group_1": { - "input_activations": { - "num_bits": 8, - "type": "int", - "symmetric": False, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, - ) - apply_quantization_config(model, config) - model.apply(lambda module: initialize_observer(module, base_name="input")) - model.apply(_prep_for_input_quant_calibration) - - # start calibration - calib_list = [ - "I am a string that", - "is used for calibration so", - "that your model is", - "quantized properly.", - ] - total_num_tokens_observed = 0 - for calib_sample in calib_list: - calib_tensor = tokenizer(calib_sample, return_tensors="pt") - _ = model(**calib_tensor) - total_num_tokens_observed += len(calib_tensor.input_ids.flatten()) +def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor: + perm = torch.randperm(columns) + return torch.tensor([index // group_size for index in range(columns)])[perm] - counter = get_observer_token_count(model) - # filter out the None values - # (tokens, in the appropriate format, that were not observed by the model) - counter = {k: v for k, v in counter.items() if v is not None} +@pytest.mark.parametrize( + "args", + [ + QuantizationArgs(strategy="tensor"), + QuantizationArgs(strategy="tensor_group", group_size=4), + ], +) +def test_flatten_for_calibration_input(args): + module = torch.nn.Linear(8, 10) + scheme = QuantizationScheme(targets=[], input_activations=args) + initialize_module_for_quantization(module, scheme) - # iterate over all the layers in the model where the token count in the proper - # format is has been observed - for i in range(model.config.num_hidden_layers): - # fetch the tokens observed by the router - tokens_observed_by_router = counter.pop( - f"model.layers.{i}.block_sparse_moe.gate" - ) - assert tokens_observed_by_router == total_num_tokens_observed + input = torch.empty((3, 5, 8)) + input_flattened = flatten_for_calibration(input, "input", scheme.input_activations) + assert input_flattened.shape[1:-1] == module.input_scale.shape + assert input_flattened.shape[1:-1] == module.input_zero_point.shape - # fetch the sum of tokens observed by all the experts - sum_tokens_observed_by_experts = 0 - keys_for_this_layer = [ - k - for k in counter.keys() - if f"model.layers.{i}.block_sparse_moe.experts" in k - ] - for key in keys_for_this_layer: - sum_tokens_observed_by_experts += counter.pop(key) - # each Mixtral expert 
is comprised of 3 linear layers, - # so we need to multiply by 3 - assert ( - sum_tokens_observed_by_experts - == total_num_tokens_observed * model.config.num_experts_per_tok * 3 - ) +@pytest.mark.parametrize( + "args,g_idx", + [ + (QuantizationArgs(strategy="tensor"), None), + (QuantizationArgs(strategy="channel"), None), + (QuantizationArgs(strategy="group", group_size=4), None), + (QuantizationArgs(strategy="group", group_size=4), make_dummy_g_idx(8, 4)), + (QuantizationArgs(strategy="tensor_group", group_size=4), None), + (QuantizationArgs(strategy="block", block_structure=[5, 4]), None), + ], +) +def test_flatten_for_calibration_weights(args, g_idx): + module = torch.nn.Linear(8, 10) + scheme = QuantizationScheme(targets=[], weights=args) + initialize_module_for_quantization(module, scheme) - # there are no more information in the counter - assert len(counter) == 0 + weight_flattened = flatten_for_calibration( + module.weight, + "weight", + scheme.weights, + g_idx=g_idx, + ) + assert weight_flattened.shape[1:-1] == module.weight_scale.shape + assert weight_flattened.shape[1:-1] == module.weight_zero_point.shape diff --git a/tests/llmcompressor/observers/test_min_max.py b/tests/llmcompressor/observers/test_min_max.py index 229c51ca7c..8edc0d8e5b 100644 --- a/tests/llmcompressor/observers/test_min_max.py +++ b/tests/llmcompressor/observers/test_min_max.py @@ -41,7 +41,7 @@ def test_min_max_observer(symmetric, expected_scale, expected_zero_point): ) observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) scale, zero_point = observer(tensor) assert round(scale.item(), 4) == expected_scale @@ -56,7 +56,7 @@ def test_min_max_observer_symmetric_scale_range(): weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="minmax") observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) scale, zero_point = observer(tensor) # if symmetric, max symmetric_range = abs(-128) / 255 @@ -82,15 +82,17 @@ def test_min_max_observer_value_update(): tensor = inp num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="minmax") + weights = QuantizationArgs( + num_bits=num_bits, strategy="tensor", symmetric=True, observer="minmax" + ) observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) curr_max = 1 curr_min = 1 for i, tensor in enumerate(tensors): observer(tensor) - curr_max = max(observer.max_val.get("default"), curr_max) - curr_min = min(observer.min_val.get("default"), curr_max) + curr_max = max(observer.max_vals[0], curr_max) + curr_min = min(observer.min_vals[0], curr_min) if i < 2: assert curr_max == 1 @@ -108,13 +110,20 @@ def test_g_idx(): input_shape = (128, 512) tensor = torch.rand(input_shape) weights = QuantizationArgs(num_bits=8, group_size=group_size, observer="minmax") + + module = torch.nn.Linear(512, 1) g_idx = make_dummy_g_idx(tensor.shape[1], group_size) + module.weight_g_idx = g_idx - observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) - scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx) + observer = Observer.load_from_registry( + weights.observer, base_name="weight", args=weights, 
module=module + ) + scale_g_idx, zero_point_g_idx = observer(tensor) - observer.reset() + observer = Observer.load_from_registry( + weights.observer, base_name="weight", args=weights, module=module + ) + del module.weight_g_idx scale, zero_point = observer(tensor[:, torch.argsort(g_idx)]) assert scale_g_idx == pytest.approx(scale) diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index 1ba79495f1..f741d42490 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -15,30 +15,45 @@ import pytest import torch +from compressed_tensors.quantization import fake_quantize from compressed_tensors.quantization.quant_args import QuantizationArgs from llmcompressor.observers import MovingAverageMSEObserver, Observer @pytest.mark.parametrize( - "symmetric,expected_scale,expected_zero_point", + "strategy,symmetric,exp_loss", [ - (True, 0.0078, 0), - (False, 0.0039, -128), + ("tensor", True, 4.8103e-06), + ("tensor", False, 1.1258e-06), + ("channel", True, 2.5675e-06), + ("channel", False, 2.3696e-07), + ("group", True, 3.1282e-06), + ("group", False, 1.3794e-07), + ("block", True, 2.8968e-06), + ("block", False, 5.6068e-07), ], ) -def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) +def test_mse_observer(strategy, symmetric, exp_loss): + tensor = torch.arange(24).reshape((6, 4)) / 24 num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") + weights = QuantizationArgs( + num_bits=num_bits, + strategy=strategy, + symmetric=symmetric, + group_size=(2 if strategy == "group" else None), + block_structure=([3, 2] if strategy == "block" else None), + observer="mse", + ) observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) - scale, zero_point = observer(tensor) - + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) assert isinstance(observer, MovingAverageMSEObserver) - assert round(scale.item(), 4) == expected_scale - assert round(zero_point.item(), 4) == expected_zero_point + + scale, zero_point = observer(tensor) + q_tensor = fake_quantize(tensor, scale, zero_point, weights) + mse_loss = torch.sum((tensor - q_tensor).abs_().pow_(2)) / tensor.numel() + assert mse_loss == pytest.approx(exp_loss, abs=1e-10) def test_mse_observer_symmetric_scale_range(): @@ -49,7 +64,7 @@ def test_mse_observer_symmetric_scale_range(): weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="mse") observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) scale, zero_point = observer(tensor) # if symmetric, max symmetric_range = abs(-128) / 255 From 1c2d550a40c9e761dc803d3ac339d28c5d323656 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Oct 2025 18:29:37 -0400 Subject: [PATCH 02/23] add torch inductor ignore Signed-off-by: Kyle Sayers --- tests/llmcompressor/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llmcompressor/conftest.py b/tests/llmcompressor/conftest.py index 04fa589281..c0d976e5ba 100644 --- a/tests/llmcompressor/conftest.py +++ b/tests/llmcompressor/conftest.py @@ -48,7 +48,7 @@ def _files_size_mb(path_list: List[str]) -> int: @pytest.fixture(scope="session", autouse=True) def check_for_created_files(): - ignore_dirs = ["__pycache__", 
"sparse_logs"] + ignore_dirs = ["__pycache__", "sparse_logs", "torchinductor"] start_files_root = _get_files(directory=r".", ignore_dirs=ignore_dirs) start_files_temp = _get_files( directory=tempfile.gettempdir(), ignore_dirs=["pytest-of"] From 32879da12259c74cdc2aeb8a907d701ca805e563 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 8 Oct 2025 16:20:06 -0400 Subject: [PATCH 03/23] ignore inductor, add fp4 test Signed-off-by: Kyle Sayers --- tests/llmcompressor/conftest.py | 11 ++++++----- tests/llmcompressor/observers/test_mse.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tests/llmcompressor/conftest.py b/tests/llmcompressor/conftest.py index c0d976e5ba..46dd950bac 100644 --- a/tests/llmcompressor/conftest.py +++ b/tests/llmcompressor/conftest.py @@ -48,10 +48,11 @@ def _files_size_mb(path_list: List[str]) -> int: @pytest.fixture(scope="session", autouse=True) def check_for_created_files(): - ignore_dirs = ["__pycache__", "sparse_logs", "torchinductor"] - start_files_root = _get_files(directory=r".", ignore_dirs=ignore_dirs) + local_ignore_dirs = ["__pycache__", "sparse_logs"] + tmp_ignore_dirs = ["pytest-of", "torchinductor"] + start_files_root = _get_files(directory=r".", ignore_dirs=local_ignore_dirs) start_files_temp = _get_files( - directory=tempfile.gettempdir(), ignore_dirs=["pytest-of"] + directory=tempfile.gettempdir(), ignore_dirs=tmp_ignore_dirs ) yield if wandb: @@ -61,7 +62,7 @@ def check_for_created_files(): shutil.rmtree(log_dir) # allow creation of __pycache__ directories - end_files_root = _get_files(directory=r".", ignore_dirs=ignore_dirs) + end_files_root = _get_files(directory=r".", ignore_dirs=local_ignore_dirs) # assert no files created in root directory while running # the pytest suite assert len(start_files_root) >= len(end_files_root), ( @@ -74,7 +75,7 @@ def check_for_created_files(): max_allowed_sized_temp_files_megabytes = 1 # pytest temp files are automatically deleted, exclude from size calculation end_files_temp = _get_files( - directory=tempfile.gettempdir(), ignore_dirs=["pytest-of"] + directory=tempfile.gettempdir(), ignore_dirs=tmp_ignore_dirs ) created_temp_files = set(end_files_temp) - set(start_files_temp) diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index f741d42490..121f871819 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -70,3 +70,22 @@ def test_mse_observer_symmetric_scale_range(): # if symmetric, max symmetric_range = abs(-128) / 255 assert round(scale.item(), 4) <= 1.0039 assert round(zero_point.item(), 4) == 0 + + +def test_mse_fp4(): + tensor = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24 + + weights = QuantizationArgs( + num_bits=4, + type="float", # must be fp4 + symmetric=True, + strategy="tensor_group", + group_size=3, + ) + + observer = weights.observer + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) + scale, zero_point = observer(tensor) + + qdq_tensor = fake_quantize(tensor, scale, zero_point, weights) + assert torch.nn.functional.mse_loss(qdq_tensor, tensor) <= 0.002 From a0b83b4ad2c6f14cb2fd2cc305331c9ad360b5a5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 8 Oct 2025 16:34:23 -0400 Subject: [PATCH 04/23] add fp4 test Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/base.py | 9 ++++++--- tests/llmcompressor/observers/test_mse.py | 18 +++++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff 
--git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 4af5c37e39..5d69297531 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -9,6 +9,7 @@ ) from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin +from compressed_tensors.utils import patch_attr from llmcompressor.observers.helpers import flatten_for_calibration @@ -82,9 +83,11 @@ def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter: :param observed: value being observed :return: calibrated global parameter """ - observed = observed.reshape((1, 1, -1)) # per tensor reshape - min_vals, max_vals = self.get_min_max(observed) - return generate_gparam(min_vals, max_vals) + # avoid updating running min/max for global scales + with patch_attr(self, "min_vals", None), patch_attr(self, "max_vals", None): + observed = observed.reshape((1, 1, -1)) # per tensor reshape + min_vals, max_vals = self.get_min_max(observed) + return generate_gparam(min_vals, max_vals) def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: if self.module is None: diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index 121f871819..8d7f9c2cb8 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -73,7 +73,8 @@ def test_mse_observer_symmetric_scale_range(): def test_mse_fp4(): - tensor = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24 + module = torch.nn.Linear(6, 4) + module.weight.data = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24 weights = QuantizationArgs( num_bits=4, @@ -84,8 +85,15 @@ def test_mse_fp4(): ) observer = weights.observer - observer = Observer.load_from_registry(observer, base_name="weight", args=weights) - scale, zero_point = observer(tensor) + observer = Observer.load_from_registry( + observer, base_name="weight", args=weights, module=module + ) - qdq_tensor = fake_quantize(tensor, scale, zero_point, weights) - assert torch.nn.functional.mse_loss(qdq_tensor, tensor) <= 0.002 + global_scale = observer.get_global_scale(module.weight) + module.weight_global_scale = global_scale + scale, zero_point = observer(module.weight) + + qdq_tensor = fake_quantize( + module.weight, scale, zero_point, weights, global_scale=global_scale + ) + assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.002 From 96b3995ce9fe8bf3ff1362c28678143bbddd0d32 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 9 Oct 2025 13:39:35 -0400 Subject: [PATCH 05/23] fix gptq observer call Signed-off-by: Kyle Sayers --- .../modifiers/quantization/gptq/gptq_quantize.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 28926650fe..21c6888eda 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -121,22 +121,23 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - scale, zero_point = observer(W, g_idx=None) + module.weight_g_idx = g_idx + scale, zero_point = observer(W) # use identity g_idx (invert permutation later) elif actorder == ActivationOrdering.WEIGHT: # update groups first, then 
permute by activation order - scale, zero_point = observer(W, g_idx=None) + scale, zero_point = observer(W) W, H, perm = _apply_activation_ordering(W, H) # permute g_idx to maintain identity mapping after unpermutation g_idx = g_idx[perm] else: - scale, zero_point = observer(W, g_idx=None) + scale, zero_point = observer(W) else: - scale, zero_point = observer(W, g_idx=None) + scale, zero_point = observer(W) # sparsity mask sparsity = tensor_sparsity(W) From 292b131934fdbe06121931145142ce896c516281 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 10 Oct 2025 01:43:46 -0400 Subject: [PATCH 06/23] abstraction Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/__init__.py | 2 + src/llmcompressor/observers/base.py | 31 ++++--- src/llmcompressor/observers/min_max.py | 63 ++++++--------- src/llmcompressor/observers/moving_base.py | 81 +++++++++++++++++++ src/llmcompressor/observers/mse.py | 78 ++++++++---------- src/llmcompressor/observers/static_base.py | 59 ++++++++++++++ .../modifiers/calibration/test_observers.py | 59 ++++++++++++++ tests/llmcompressor/observers/test_mse.py | 16 +++- 8 files changed, 292 insertions(+), 97 deletions(-) create mode 100644 src/llmcompressor/observers/moving_base.py create mode 100644 src/llmcompressor/observers/static_base.py diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py index a019d33c2e..4b6d5707f2 100644 --- a/src/llmcompressor/observers/__init__.py +++ b/src/llmcompressor/observers/__init__.py @@ -11,5 +11,7 @@ from .helpers import * from .base import * +from .moving_base import * +from .static_base import * from .min_max import * from .mse import * diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 5d69297531..bdbec7db32 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -9,7 +9,6 @@ ) from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin -from compressed_tensors.utils import patch_attr from llmcompressor.observers.helpers import flatten_for_calibration @@ -42,12 +41,13 @@ def __init__( # used for moving averages and testing self.min_vals = None self.max_vals = None + self.global_min_vals = None + self.global_max_vals = None @abstractmethod def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Calculates updated scales and zero points from observed value - (weight, activation, or attention state). + Calculate min and max values from observed value :param observed: value being observed whose shape is (num_observations, *qparam_shape, group_size) @@ -55,9 +55,23 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso """ raise NotImplementedError() + @abstractmethod + def get_global_min_max( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate min and max values from observed value for the purposes of + global scale calculation + + :param observed: value being observed whose shape is + (num_observations, 1, group_size) + :return: minimum value and maximum value whose shapes are (1, ) + """ + raise NotImplementedError() + def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Calculates updated scales and zero points from observed value + Calculate updated scales and zero points from observed value (weight, activation, or attention state). 
:param observed: value being observed @@ -78,16 +92,15 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter: """ - Calculates updated global scale from observed value + Calculate updated global scale from observed value :param observed: value being observed :return: calibrated global parameter """ # avoid updating running min/max for global scales - with patch_attr(self, "min_vals", None), patch_attr(self, "max_vals", None): - observed = observed.reshape((1, 1, -1)) # per tensor reshape - min_vals, max_vals = self.get_min_max(observed) - return generate_gparam(min_vals, max_vals) + observed = observed.reshape((1, 1, -1)) # per tensor reshape + self.global_min_vals, self.global_max_vals = self.get_global_min_max(observed) + return generate_gparam(self.global_min_vals, self.global_max_vals) def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: if self.module is None: diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index 5dbe8f31e4..2744a9ca72 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -1,56 +1,39 @@ -from typing import Optional, Tuple +from typing import Tuple import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs from llmcompressor.observers.base import Observer +from llmcompressor.observers.moving_base import MovingAverageObserverBase +from llmcompressor.observers.static_base import StaticObserverBase -__all__ = ["MinMaxObserver"] +__all__ = ["StaticMinMaxObserver", "MinMaxObserver"] -@Observer.register("minmax") -class MinMaxObserver(Observer): +@Observer.register("static_minmax") +class StaticMinMaxObserver(StaticObserverBase): """ Implements a quantization observer that calculates scale and zero point based on the - minimum and maximum values of the tensor being observed. If averaging_constant is - specified, then the scales are updated using a moving average + the minimum and maximum values of all observed values """ - def __init__( - self, - base_name: str, - args: QuantizationArgs, - module: Optional[torch.nn.Module] = None, - **observer_kwargs, - ): - super().__init__(base_name, args, module, **observer_kwargs) - - observer_kwargs = self.args.observer_kwargs - self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01) - - def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Calculates updated scales and zero points from observed value using the absolute - min and max value. If `averaging_constant` is specified, then subsequent calls - will affect a moving average by the specified constant. 
- - :param observed: value being observed whose shape is - (num_observations, *qparam_shape, group_size) - :return: minimum value and maximum value whose shapes are (*qparam_shape, ) - """ + def get_current_min_max( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: min_vals = torch.amin(observed, dim=(0, -1)) max_vals = torch.amax(observed, dim=(0, -1)) - if self.min_vals is not None and self.averaging_constant != 1.0: - # FUTURE: consider scaling by num observations (first dim) - # rather than reducing by first dim - min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant) - max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant) - return min_vals, max_vals - def _lerp( - self, input: torch.Tensor, end: torch.Tensor, weight: float - ) -> torch.Tensor: - """torch lerp_kernel is not implemeneted for all data types""" - return (input * (1.0 - weight)) + (end * weight) + +@Observer.register("minmax") +class MinMaxObserver(MovingAverageObserverBase): + """ + Implements a quantization observer that calculates scale and zero point based on the + minimum and maximum values of the tensor being observed. If averaging_constant is + specified, then the scales are updated using a moving average + """ + + def get_current_min_max( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return StaticMinMaxObserver.get_current_min_max(self, observed) diff --git a/src/llmcompressor/observers/moving_base.py b/src/llmcompressor/observers/moving_base.py new file mode 100644 index 0000000000..dd48077910 --- /dev/null +++ b/src/llmcompressor/observers/moving_base.py @@ -0,0 +1,81 @@ +from abc import abstractmethod +from typing import Optional, Tuple + +import torch +from compressed_tensors.quantization.quant_args import QuantizationArgs + +from llmcompressor.observers.base import Observer + +__all__ = ["MovingAverageObserverBase"] + + +class MovingAverageObserverBase(Observer): + """ + Implements a quantization observer that calculates scale and zero point based on the + minimum and maximum values of the tensor being observed. 
If averaging_constant is
+    specified, then the scales are updated using a moving average
+    """
+
+    def __init__(
+        self,
+        base_name: str,
+        args: QuantizationArgs,
+        module: Optional[torch.nn.Module] = None,
+        **observer_kwargs,
+    ):
+        super().__init__(base_name, args, module, **observer_kwargs)
+        self.avg_constant = self.args.observer_kwargs.get("averaging_constant", 0.01)
+
+    @abstractmethod
+    def get_current_min_max(
+        self, observed: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate the min and max value of the observed value (without moving average)
+        """
+        raise NotImplementedError()
+
+    def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate moving average of min and max values from observed value
+
+        :param observed: value being observed whose shape is
+            (num_observations, *qparam_shape, group_size)
+        :return: minimum value and maximum value whose shapes are (*qparam_shape, )
+        """
+        min_vals, max_vals = self.get_current_min_max(observed)
+
+        if self.min_vals is not None and self.avg_constant != 1.0:
+            # FUTURE: consider scaling by num observations (first dim)
+            # rather than reducing by first dim
+            min_vals = self._lerp(self.min_vals, min_vals, self.avg_constant)
+            max_vals = self._lerp(self.max_vals, max_vals, self.avg_constant)
+
+        return min_vals, max_vals
+
+    def get_global_min_max(
+        self, observed: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Calculate moving average of min and max values from observed value for the
+        purposes of global scale calculation
+
+        :param observed: value being observed whose shape is
+            (num_observations, 1, group_size)
+        :return: minimum value and maximum value whose shapes are (1, )
+        """
+        min_vals, max_vals = self.get_current_min_max(observed)
+
+        if self.global_min_vals is not None and self.avg_constant != 1.0:
+            # FUTURE: consider scaling by num observations (first dim)
+            # rather than reducing by first dim
+            min_vals = self._lerp(self.global_min_vals, min_vals, self.avg_constant)
+            max_vals = self._lerp(self.global_max_vals, max_vals, self.avg_constant)
+
+        return min_vals, max_vals
+
+    def _lerp(
+        self, input: torch.Tensor, end: torch.Tensor, weight: float
+    ) -> torch.Tensor:
+        """torch lerp_kernel is not implemented for all data types"""
+        return (input * (1.0 - weight)) + (end * weight)
diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py
index c33c08d6d2..b9b1d4b408 100644
--- a/src/llmcompressor/observers/mse.py
+++ b/src/llmcompressor/observers/mse.py
@@ -1,62 +1,38 @@
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
+from compressed_tensors.quantization.lifecycle import fake_quantize
 from compressed_tensors.quantization.quant_args import (
-    QuantizationArgs,
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.utils import calculate_qparams
 from compressed_tensors.utils import patch_attr
 
 from llmcompressor.observers.base import Observer
+from llmcompressor.observers.moving_base import MovingAverageObserverBase
+from llmcompressor.observers.static_base import StaticObserverBase
 
-__all__ = ["MovingAverageMSEObserver"]
+__all__ = ["StaticMSEObserver", "MovingAverageMSEObserver"]
 
 
-@Observer.register("mse")
-class MovingAverageMSEObserver(Observer):
+@Observer.register("static_mse")
+class StaticMSEObserver(StaticObserverBase):
     """
-    Implements a dynamic quantization observer that sets the scale and
-    zero point based on a moving average of the mse-clipped min and max observed values
+
TODO """ - def __init__( - self, - base_name: str, - args: QuantizationArgs, - module: Optional[torch.nn.Module] = None, - **observer_kwargs, - ): - super().__init__(base_name, args, module, **observer_kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) observer_kwargs = self.args.observer_kwargs self.maxshrink = observer_kwargs.get("maxshrink", 0.20) self.patience = observer_kwargs.get("patience", 5) - self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01) self.grid = observer_kwargs.get("grid", 100.0) self.norm = observer_kwargs.get("norm", 2.4) - def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Calculates updated scales and zero points from observed value. Minimum and - maximum values are chosen by grid searching across min/max values which minimize - quantization reconstruction loss. - - :param observed: value being observed whose shape is - (num_observations, *qparam_shape, group_size) - :return: minimum value and maximum value whose shapes are (*qparam_shape, ) - """ - min_vals, max_vals = self._mse_min_max(observed) - - if self.min_vals is not None and self.averaging_constant != 1.0: - # FUTURE: consider scaling by num observations (first dim) - # rather than reducing by first dim - min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant) - max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant) - - return min_vals, max_vals - - def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def get_current_min_max( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Grid search for MSE-optimal min and max values @@ -64,8 +40,6 @@ def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens (num_observations, *qparam_shape, group_size) :return: minimum and maximum values which minimize reconstruction error """ - from compressed_tensors.quantization.lifecycle import fake_quantize - absolute_min_val = torch.amin(observed, dim=(0, -1)) absolute_max_val = torch.amax(observed, dim=(0, -1)) best = torch.full_like( @@ -99,7 +73,7 @@ def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens candidate_zero_points.unsqueeze(-1), self.args, global_scale=global_scale, - ) + ).to(observed.dtype) q -= observed q.abs_() @@ -119,8 +93,22 @@ def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens return min_val, max_val - def _lerp( - self, input: torch.Tensor, end: torch.Tensor, weight: float - ) -> torch.Tensor: - """torch lerp_kernel is not implemeneted for all data types""" - return (input * (1.0 - weight)) + (end * weight) + +@Observer.register("mse") +class MovingAverageMSEObserver(MovingAverageObserverBase): + """ + Implements a dynamic quantization observer that sets the scale and + zero point based on a moving average of the mse-clipped min and max observed values + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + observer_kwargs = self.args.observer_kwargs + self.maxshrink = observer_kwargs.get("maxshrink", 0.20) + self.patience = observer_kwargs.get("patience", 5) + self.grid = observer_kwargs.get("grid", 100.0) + self.norm = observer_kwargs.get("norm", 2.4) + + def get_current_min_max(self, observed): + return StaticMSEObserver.get_current_min_max(self, observed) diff --git a/src/llmcompressor/observers/static_base.py b/src/llmcompressor/observers/static_base.py new file mode 100644 index 0000000000..58e7237ab2 --- 
/dev/null +++ b/src/llmcompressor/observers/static_base.py @@ -0,0 +1,59 @@ +from abc import abstractmethod +from typing import Tuple + +import torch + +from llmcompressor.observers.base import Observer + +__all__ = ["StaticObserverBase"] + + +class StaticObserverBase(Observer): + """ + Implements a quantization observer that calculates scale and zero point based on the + minimum and maximum values of all observed values + """ + + @abstractmethod + def get_current_min_max( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the min and max value of the observed value + """ + raise NotImplementedError() + + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate min and max values from all observed values + + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) + """ + min_vals, max_vals = self.get_current_min_max(observed) + + if self.min_vals is not None: + min_vals = torch.min(min_vals, self.min_vals) + max_vals = torch.max(max_vals, self.max_vals) + + return min_vals, max_vals + + def get_global_min_max( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate min and max values from all observed values for the purposes of global + scale calculation + + :param observed: value being observed whose shape is + (num_observations, 1, group_size) + :return: minimum value and maximum value whose shapes are (1, ) + """ + min_vals, max_vals = self.get_current_min_max(observed) + + if self.global_min_vals is not None: + min_vals = torch.min(min_vals, self.global_min_vals) + max_vals = torch.max(max_vals, self.global_max_vals) + + return min_vals, max_vals diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index 57f4de40bf..0c7d550ff5 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -7,6 +7,7 @@ ) from llmcompressor.modifiers.quantization.calibration import initialize_observer +from llmcompressor.observers import Observer @pytest.mark.parametrize( @@ -58,3 +59,61 @@ def test_observers_update(shape, group_size, actorder): def assert_alike(a, b): assert a.dtype == b.dtype assert a.shape == b.shape + + +@pytest.mark.parametrize("is_global", [False, True]) +@pytest.mark.parametrize( + "name,kwargs,observed,exp_min_vals,exp_max_vals", + ( + ( + "static_minmax", + {}, + torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), + torch.tensor([[0.0], [-3.0], [-3.0]]), + torch.tensor([[0.0], [1.0], [3.0]]), + ), + ( + "static_mse", + {}, + torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), + torch.tensor([[0.0], [-3.0], [-3.0]]), + torch.tensor([[0.0], [1.0], [3.0]]), + ), + ( + "minmax", # moving average + {"averaging_constant": 0.1}, + torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), + torch.tensor([[0.0], [-0.3], [-0.37]]), + torch.tensor([[0.0], [0.1], [0.39]]), + ), + ( + "mse", # moving average + {"averaging_constant": 0.1}, + torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), + torch.tensor([[0.0], [-0.3], [-0.37]]), + torch.tensor([[0.0], [0.1], [0.39]]), + ), + ), +) +def test_observer_moving_static( + name, kwargs, observed, exp_min_vals, exp_max_vals, is_global +): + observer = Observer.load_from_registry( + name, base_name="input", args=QuantizationArgs(strategy="tensor"), 
**kwargs + ) + + min_vals, max_vals = [], [] + for _observed in observed: + if not is_global: + observer(_observed) + min_vals.append(observer.min_vals) + max_vals.append(observer.max_vals) + else: + observer.get_global_scale(_observed) + min_vals.append(observer.global_min_vals) + max_vals.append(observer.global_max_vals) + + min_vals = torch.stack(min_vals) + max_vals = torch.stack(max_vals) + assert torch.allclose(min_vals, exp_min_vals) + assert torch.allclose(max_vals, exp_max_vals) diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index 8d7f9c2cb8..618840b438 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -84,9 +84,8 @@ def test_mse_fp4(): group_size=3, ) - observer = weights.observer observer = Observer.load_from_registry( - observer, base_name="weight", args=weights, module=module + "mse", base_name="weight", args=weights, module=module ) global_scale = observer.get_global_scale(module.weight) @@ -96,4 +95,15 @@ def test_mse_fp4(): qdq_tensor = fake_quantize( module.weight, scale, zero_point, weights, global_scale=global_scale ) - assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.002 + assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.0015 # 0.0013 + + # sanity check: scales calibrated without global scales are worse + observer = Observer.load_from_registry( + "mse", base_name="weight", args=weights, module=module + ) + global_scale = observer.get_global_scale(module.weight) + scale, zero_point = observer(module.weight) # no global scale + qdq_tensor = fake_quantize( + module.weight, scale, zero_point, weights, global_scale=global_scale + ) + assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) >= 0.0035 # 0.0036 From de6f30267b45e70d46d2be9b3d45e548f3455704 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 10 Oct 2025 01:47:03 -0400 Subject: [PATCH 07/23] comments Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/mse.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index b9b1d4b408..ed1fc67a37 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -74,6 +74,8 @@ def get_current_min_max( self.args, global_scale=global_scale, ).to(observed.dtype) + # Note that due to forward quantization implementation, token quant, + # unlike tensor_group, requires extra dtype cast q -= observed q.abs_() From 599912ee2c83ed32e1fa64a744b46d99e4f824f0 Mon Sep 17 00:00:00 2001 From: dhuangnm <74931910+dhuangnm@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:52:03 -0400 Subject: [PATCH 08/23] Pick up compressed-tensors 0.12.2 for patch release (#1904) SUMMARY: Pick up compressed-tensors 0.12.2 for patch release 0.8.1 TEST PLAN: All tests Signed-off-by: Dan Huang --- .github/workflows/test-check-transformers.yaml | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-check-transformers.yaml b/.github/workflows/test-check-transformers.yaml index 0dd8876655..4ffde0b5e2 100644 --- a/.github/workflows/test-check-transformers.yaml +++ b/.github/workflows/test-check-transformers.yaml @@ -60,7 +60,7 @@ jobs: steps: - uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.10' - uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/setup.py b/setup.py index a178bdf03a..00538afbe6 100644 --- a/setup.py +++ b/setup.py @@ -140,9 +140,9 @@ def 
localversion_func(version: ScmVersion) -> str: ), ("pillow>=10.4.0,<=11.3.0" if BUILD_TYPE == "release" else "pillow>=10.4.0"), ( - "compressed-tensors==0.12.1" + "compressed-tensors==0.12.2" if BUILD_TYPE == "release" - else "compressed-tensors>=0.12.2a2" + else "compressed-tensors>=0.12.3a2" ), ], extras_require={ From 21083a0c2559d3cdeed0168ff314a2b8dff337ed Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 10 Oct 2025 14:33:39 -0400 Subject: [PATCH 09/23] use offload utils Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 21c6888eda..f632e3029a 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -10,6 +10,7 @@ QuantizationStrategy, fake_quantize, ) +from compressed_tensors.utils import update_offload_parameter from loguru import logger from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD @@ -121,7 +122,7 @@ def quantize_weight( if actorder == ActivationOrdering.GROUP: # permute by activation order first, then update groups W, H, perm = _apply_activation_ordering(W, H) - module.weight_g_idx = g_idx + update_offload_parameter(module, "weight_g_idx", g_idx) scale, zero_point = observer(W) # use identity g_idx (invert permutation later) From e722e20a573781995bbf63d24144b7981ef66653 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 10:08:52 -0400 Subject: [PATCH 10/23] small cleanup Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/mse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index ed1fc67a37..f70961caeb 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -23,7 +23,6 @@ class StaticMSEObserver(StaticObserverBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - observer_kwargs = self.args.observer_kwargs self.maxshrink = observer_kwargs.get("maxshrink", 0.20) self.patience = observer_kwargs.get("patience", 5) From 5edfe0e3037a102138a5e2536b1d101a7d9437ee Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 14:21:39 -0400 Subject: [PATCH 11/23] update tests Signed-off-by: Kyle Sayers --- .../quantization/gptq/gptq_quantize.py | 2 +- src/llmcompressor/observers/__init__.py | 1 - src/llmcompressor/observers/base.py | 84 ++++--- src/llmcompressor/observers/min_max.py | 78 +++++-- src/llmcompressor/observers/moving_base.py | 54 +++-- src/llmcompressor/observers/mse.py | 208 +++++++++++------- src/llmcompressor/observers/static_base.py | 59 ----- .../modifiers/calibration/test_observers.py | 18 +- tests/llmcompressor/observers/test_min_max.py | 6 +- 9 files changed, 283 insertions(+), 227 deletions(-) delete mode 100644 src/llmcompressor/observers/static_base.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index f632e3029a..9464caff5e 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -96,7 +96,7 @@ def quantize_weight( # create observer for calculating quantization parameters observer = Observer.load_from_registry( - "minmax", + "memoryless_minmax", base_name="weight", 
args=quant_args, module=module, diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py index 4b6d5707f2..c2fd7de53f 100644 --- a/src/llmcompressor/observers/__init__.py +++ b/src/llmcompressor/observers/__init__.py @@ -12,6 +12,5 @@ from .helpers import * from .base import * from .moving_base import * -from .static_base import * from .min_max import * from .mse import * diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index bdbec7db32..91fd1092e6 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -4,15 +4,16 @@ import torch from compressed_tensors import InternalModule -from compressed_tensors.quantization.quant_args import ( - QuantizationArgs, -) +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin from llmcompressor.observers.helpers import flatten_for_calibration -__all__ = ["Observer"] +__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple"] + +MinMaxTuple = Tuple[torch.Tensor, torch.Tensor] +ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor] class Observer(InternalModule, RegistryMixin): @@ -38,38 +39,28 @@ def __init__( self.args.observer_kwargs = self.args.observer_kwargs or {} self.args.observer_kwargs.update(observer_kwargs) - # used for moving averages and testing - self.min_vals = None - self.max_vals = None - self.global_min_vals = None - self.global_max_vals = None - @abstractmethod - def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ Calculate min and max values from observed value - :param observed: value being observed whose shape is - (num_observations, *qparam_shape, group_size) + :param observed: value of shape (num_observations, *qparam_shape, group_size) :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ raise NotImplementedError() @abstractmethod - def get_global_min_max( - self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ Calculate min and max values from observed value for the purposes of global scale calculation - :param observed: value being observed whose shape is - (num_observations, 1, group_size) + :param observed: value of shape (num_observations, 1, group_size) :return: minimum value and maximum value whose shapes are (1, ) """ raise NotImplementedError() - def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, observed: torch.Tensor) -> ScaleZpTuple: """ Calculate updated scales and zero points from observed value (weight, activation, or attention state). @@ -77,33 +68,60 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: :param observed: value being observed :return: calibrated scale and zero point """ + scales, zero_points, _min, _max = self._forward_with_minmax(observed) + return (scales, zero_points) + + def get_global_scale(self, observed: torch.Tensor) -> torch.Tensor: + """ + Calculate updated global scale from observed value + (weight, activation, or attention state). 
+ + :param observed: value being observed + :return: calibrated global parameter + """ + global_scale, _min, _max = self._get_global_scale_with_minmax(observed) + return global_scale + + def _forward_with_minmax( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: g_idx = self._get_module_param("g_idx") global_scale = self._get_module_param("global_scale") + self._check_has_global_scale(global_scale) observed = flatten_for_calibration(observed, self.base_name, self.args, g_idx) - self.min_vals, self.max_vals = self.get_min_max(observed) + min_vals, max_vals = self.get_min_max(observed) - return calculate_qparams( - min_vals=self.min_vals, - max_vals=self.max_vals, + scales, zero_points = calculate_qparams( + min_vals=min_vals, + max_vals=max_vals, quantization_args=self.args, global_scale=global_scale, ) + return scales, zero_points, min_vals, max_vals - def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter: - """ - Calculate updated global scale from observed value - - :param observed: value being observed - :return: calibrated global parameter - """ - # avoid updating running min/max for global scales + def _get_global_scale_with_minmax( + self, observed: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: observed = observed.reshape((1, 1, -1)) # per tensor reshape - self.global_min_vals, self.global_max_vals = self.get_global_min_max(observed) - return generate_gparam(self.global_min_vals, self.global_max_vals) + + global_min_vals, global_max_vals = self.get_global_min_max(observed) + global_scale = generate_gparam(global_min_vals, global_max_vals) + + return global_scale, global_min_vals, global_max_vals def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: if self.module is None: return None return getattr(self.module(), f"{self.base_name}_{name}", None) + + def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): + if ( + self.args.strategy == QuantizationStrategy.TENSOR_GROUP + and global_scale is None + ): + raise ValueError( + "Cannot compute scale and zero points " + "without first computing global scale" + ) diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index 2744a9ca72..c4254608c1 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -1,26 +1,58 @@ -from typing import Tuple - import torch -from llmcompressor.observers.base import Observer +from llmcompressor.observers.base import MinMaxTuple, Observer from llmcompressor.observers.moving_base import MovingAverageObserverBase -from llmcompressor.observers.static_base import StaticObserverBase -__all__ = ["StaticMinMaxObserver", "MinMaxObserver"] +__all__ = ["MemorylessMinMaxObserver", "StaticMinMaxObserver", "MinMaxObserver"] + + +@Observer.register("memoryless_minmax") +class MemorylessMinMaxObserver(Observer): + """ + TODO + """ + + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + return _get_min_max(observed) + + def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + return _get_min_max(observed) @Observer.register("static_minmax") -class StaticMinMaxObserver(StaticObserverBase): +class StaticMinMaxObserver(Observer): """ - Implements a quantization observer that calculates scale and zero point based on the - the minimum and maximum values of all observed values + TODO """ - def get_current_min_max( - self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - min_vals 
= torch.amin(observed, dim=(0, -1)) - max_vals = torch.amax(observed, dim=(0, -1)) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.past_min_vals = None + self.past_max_vals = None + self.past_global_min_vals = None + self.past_global_max_vals = None + + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + min_vals, max_vals = _get_min_max(observed) + + if self.past_min_vals is not None: + min_vals = torch.min(min_vals, self.past_min_vals) + max_vals = torch.max(max_vals, self.past_max_vals) + + self.past_min_vals = min_vals + self.past_max_vals = max_vals + + return min_vals, max_vals + + def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + min_vals, max_vals = _get_min_max(observed) + + if self.past_global_min_vals is not None: + min_vals = torch.min(min_vals, self.past_global_min_vals) + max_vals = torch.max(max_vals, self.past_global_max_vals) + + self.past_global_min_vals = min_vals + self.past_global_max_vals = max_vals return min_vals, max_vals @@ -28,12 +60,18 @@ def get_current_min_max( @Observer.register("minmax") class MinMaxObserver(MovingAverageObserverBase): """ - Implements a quantization observer that calculates scale and zero point based on the - minimum and maximum values of the tensor being observed. If averaging_constant is - specified, then the scales are updated using a moving average + TODO """ - def get_current_min_max( - self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - return StaticMinMaxObserver.get_current_min_max(self, observed) + def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + return _get_min_max(observed) + + def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + return _get_min_max(observed) + + +def _get_min_max(observed: torch.Tensor) -> MinMaxTuple: + min_vals = torch.amin(observed, dim=(0, -1)) + max_vals = torch.amax(observed, dim=(0, -1)) + + return min_vals, max_vals diff --git a/src/llmcompressor/observers/moving_base.py b/src/llmcompressor/observers/moving_base.py index dd48077910..d8acf0c1b0 100644 --- a/src/llmcompressor/observers/moving_base.py +++ b/src/llmcompressor/observers/moving_base.py @@ -1,19 +1,17 @@ from abc import abstractmethod -from typing import Optional, Tuple +from typing import Optional import torch from compressed_tensors.quantization.quant_args import QuantizationArgs -from llmcompressor.observers.base import Observer +from llmcompressor.observers.base import MinMaxTuple, Observer __all__ = ["MovingAverageObserverBase"] class MovingAverageObserverBase(Observer): """ - Implements a quantization observer that calculates scale and zero point based on the - minimum and maximum values of the tensor being observed. 
If averaging_constant is - specified, then the scales are updated using a moving average + TODO """ def __init__( @@ -26,16 +24,26 @@ def __init__( super().__init__(base_name, args, module, **observer_kwargs) self.avg_constant = self.args.observer_kwargs.get("averaging_constant", 0.01) + self.past_min_vals = None + self.past_max_vals = None + self.past_global_min_vals = None + self.past_global_max_vals = None + @abstractmethod - def get_current_min_max( - self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ Calculate the min and max value of the observed value (without moving average) """ raise NotImplementedError() - def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + @abstractmethod + def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + """ + Calculate the min and max value of the observed value (without moving average) + """ + raise NotImplementedError() + + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ Calculate moving average of min and max values from observed value @@ -45,17 +53,18 @@ def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso """ min_vals, max_vals = self.get_current_min_max(observed) - if self.min_vals is not None and self.avg_constant != 1.0: + if self.past_min_vals is not None and self.avg_constant != 1.0: # FUTURE: consider scaling by num observations (first dim) # rather than reducing by first dim - min_vals = self._lerp(self.min_vals, min_vals, self.avg_constant) - max_vals = self._lerp(self.max_vals, max_vals, self.avg_constant) + min_vals = self._lerp(self.past_min_vals, min_vals, self.avg_constant) + max_vals = self._lerp(self.past_max_vals, max_vals, self.avg_constant) + + self.past_min_vals = min_vals + self.past_max_vals = max_vals return min_vals, max_vals - def get_global_min_max( - self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ Calculate moving average of min and max values from observed value for the purposes of global scale calculation @@ -64,13 +73,20 @@ def get_global_min_max( (num_observations, 1, group_size) :return: minimum value and maximum value whose shapes are (1, ) """ - min_vals, max_vals = self.get_current_min_max(observed) + min_vals, max_vals = self.get_current_global_min_max(observed) - if self.global_min_vals is not None and self.avg_constant != 1.0: + if self.past_global_min_vals is not None and self.avg_constant != 1.0: # FUTURE: consider scaling by num observations (first dim) # rather than reducing by first dim - min_vals = self._lerp(self.global_min_vals, min_vals, self.avg_constant) - max_vals = self._lerp(self.global_max_vals, max_vals, self.avg_constant) + min_vals = self._lerp( + self.past_global_min_vals, min_vals, self.avg_constant + ) + max_vals = self._lerp( + self.past_global_max_vals, max_vals, self.avg_constant + ) + + self.past_global_min_vals = min_vals + self.past_global_max_vals = max_vals return min_vals, max_vals diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index f70961caeb..6a5e31974a 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -1,26 +1,22 @@ -from typing import Tuple +from typing import Optional import torch -from compressed_tensors.quantization.lifecycle import fake_quantize -from compressed_tensors.quantization.quant_args import ( 
+from compressed_tensors.quantization import ( + QuantizationArgs, QuantizationStrategy, ) -from compressed_tensors.quantization.utils import calculate_qparams +from compressed_tensors.quantization.lifecycle import fake_quantize +from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.utils import patch_attr -from llmcompressor.observers.base import Observer +from llmcompressor.observers.base import MinMaxTuple, Observer from llmcompressor.observers.moving_base import MovingAverageObserverBase -from llmcompressor.observers.static_base import StaticObserverBase - -__all__ = ["StaticMSEObserver", "MovingAverageMSEObserver"] +__all__ = ["MovingAverageMSEObserver"] -@Observer.register("static_mse") -class StaticMSEObserver(StaticObserverBase): - """ - TODO - """ +@Observer.register("memoryless_mse") +class MemorylessMSEObserver(Observer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) observer_kwargs = self.args.observer_kwargs @@ -29,70 +25,30 @@ def __init__(self, *args, **kwargs): self.grid = observer_kwargs.get("grid", 100.0) self.norm = observer_kwargs.get("norm", 2.4) - def get_current_min_max( - self, observed: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Grid search for MSE-optimal min and max values - - :param observed: value being observed whose shape is - (num_observations, *qparam_shape, group_size) - :return: minimum and maximum values which minimize reconstruction error - """ - absolute_min_val = torch.amin(observed, dim=(0, -1)) - absolute_max_val = torch.amax(observed, dim=(0, -1)) - best = torch.full_like( - absolute_min_val, torch.finfo(absolute_min_val.dtype).max - ) - min_val = torch.ones_like(absolute_min_val) - max_val = torch.zeros_like(absolute_max_val) + def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: global_scale = self._get_module_param("global_scale") + return _grid_search_mse( + observed, + self.args, + self.maxshrink, + self.patience, + self.grid, + self.norm, + global_scale=global_scale, + optimize_global_scale=False, + ) - # Early stopping params - no_improve_count = 0 - - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - shrinked_min_val = p * absolute_min_val - shrinked_max_val = p * absolute_max_val - - candidate_scales, candidate_zero_points = calculate_qparams( - min_vals=shrinked_min_val, - max_vals=shrinked_max_val, - quantization_args=self.args, - global_scale=global_scale, - ) - - # Note that observed.shape = (num_observations, *qparams_shape, group_size). 
- # For the purposes of fake quantization, this is equivalent to token quant - with patch_attr(self.args, "strategy", QuantizationStrategy.TOKEN): - q = fake_quantize( - observed, - candidate_scales.unsqueeze(-1), - candidate_zero_points.unsqueeze(-1), - self.args, - global_scale=global_scale, - ).to(observed.dtype) - # Note that due to forward quantization implementation, token quant, - # unlike tensor_group, requires extra dtype cast - - q -= observed - q.abs_() - q.pow_(self.norm) - err = torch.sum(q, dim=(0, -1)) - - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - min_val[tmp] = shrinked_min_val[tmp] - max_val[tmp] = shrinked_max_val[tmp] - no_improve_count = 0 - else: - no_improve_count += 1 - if no_improve_count >= self.patience: - break - - return min_val, max_val + def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + return _grid_search_mse( + observed, + self.args, + self.maxshrink, + self.patience, + self.grid, + self.norm, + global_scale=None, + optimize_global_scale=True, + ) @Observer.register("mse") @@ -104,12 +60,108 @@ class MovingAverageMSEObserver(MovingAverageObserverBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - observer_kwargs = self.args.observer_kwargs self.maxshrink = observer_kwargs.get("maxshrink", 0.20) self.patience = observer_kwargs.get("patience", 5) self.grid = observer_kwargs.get("grid", 100.0) self.norm = observer_kwargs.get("norm", 2.4) - def get_current_min_max(self, observed): - return StaticMSEObserver.get_current_min_max(self, observed) + def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + global_scale = self._get_module_param("global_scale") + return _grid_search_mse( + observed, + self.args, + self.maxshrink, + self.patience, + self.grid, + self.norm, + global_scale=global_scale, + optimize_global_scale=False, + ) + + def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + return _grid_search_mse( + observed, + self.args, + self.maxshrink, + self.patience, + self.grid, + self.norm, + global_scale=None, + optimize_global_scale=True, + ) + + +def _grid_search_mse( + observed: torch.Tensor, + args: QuantizationArgs, + maxshrink: float, + patience: float, + grid: float, + norm: float, + global_scale: Optional[torch.Tensor] = None, + optimize_global_scale: bool = False, +) -> MinMaxTuple: + """ + Grid search for MSE-optimal min and max values. 
If global scale is not given,
+    then one is generated per candidate when optimize_global_scale is True.
+
+    :param observed: value being observed whose shape is
+        (num_observations, *qparam_shape, group_size)
+    :return: minimum and maximum values which minimize reconstruction error
+    """
+    absolute_min_val = torch.amin(observed, dim=(0, -1))
+    absolute_max_val = torch.amax(observed, dim=(0, -1))
+    best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max)
+    min_val = torch.ones_like(absolute_min_val)
+    max_val = torch.zeros_like(absolute_max_val)
+
+    # Early stopping params
+    no_improve_count = 0
+
+    # @ksayers @HGCharles: investigate searching over separate shrinking factors
+    for i in range(int(maxshrink * grid)):
+        p = 1 - i / grid
+        shrinked_min_val = p * absolute_min_val
+        shrinked_max_val = p * absolute_max_val
+
+        if optimize_global_scale:
+            global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
+
+        candidate_scales, candidate_zero_points = calculate_qparams(
+            min_vals=shrinked_min_val,
+            max_vals=shrinked_max_val,
+            quantization_args=args,
+            global_scale=global_scale,
+        )
+
+        # Note that observed.shape = (num_observations, *qparams_shape, group_size).
+        # For the purposes of fake quantization, this is equivalent to token quant
+        with patch_attr(args, "strategy", QuantizationStrategy.TOKEN):
+            q = fake_quantize(
+                observed,
+                candidate_scales.unsqueeze(-1),
+                candidate_zero_points.unsqueeze(-1),
+                args,
+                global_scale=global_scale,
+            ).to(observed.dtype)
+            # Note that due to forward quantization implementation, token quant,
+            # unlike tensor_group, requires extra dtype cast
+
+        q -= observed
+        q.abs_()
+        q.pow_(norm)
+        err = torch.sum(q, dim=(0, -1))
+
+        tmp = err < best
+        if torch.any(tmp):
+            best[tmp] = err[tmp]
+            min_val[tmp] = shrinked_min_val[tmp]
+            max_val[tmp] = shrinked_max_val[tmp]
+            no_improve_count = 0
+        else:
+            no_improve_count += 1
+            if no_improve_count >= patience:
+                break
+
+    return min_val, max_val
diff --git a/src/llmcompressor/observers/static_base.py b/src/llmcompressor/observers/static_base.py
deleted file mode 100644
index 58e7237ab2..0000000000
--- a/src/llmcompressor/observers/static_base.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-from llmcompressor.observers.base import Observer
-
-__all__ = ["StaticObserverBase"]
-
-
-class StaticObserverBase(Observer):
-    """
-    Implements a quantization observer that calculates scale and zero point based on the
-    minimum and maximum values of all observed values
-    """
-
-    @abstractmethod
-    def get_current_min_max(
-        self, observed: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate the min and max value of the observed value
-        """
-        raise NotImplementedError()
-
-    def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate min and max values from all observed values
-
-        :param observed: value being observed whose shape is
-            (num_observations, *qparam_shape, group_size)
-        :return: minimum value and maximum value whose shapes are (*qparam_shape, )
-        """
-        min_vals, max_vals = self.get_current_min_max(observed)
-
-        if self.min_vals is not None:
-            min_vals = torch.min(min_vals, self.min_vals)
-            max_vals = torch.max(max_vals, self.max_vals)
-
-        return min_vals, max_vals
-
-    def get_global_min_max(
-        self, observed: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate min and max values from all observed values for the purposes of global
-        scale calculation
-
-        :param observed: value being observed whose shape
is - (num_observations, 1, group_size) - :return: minimum value and maximum value whose shapes are (1, ) - """ - min_vals, max_vals = self.get_current_min_max(observed) - - if self.global_min_vals is not None: - min_vals = torch.min(min_vals, self.global_min_vals) - max_vals = torch.max(max_vals, self.global_max_vals) - - return min_vals, max_vals diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index 0c7d550ff5..ee44f0510d 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -72,13 +72,6 @@ def assert_alike(a, b): torch.tensor([[0.0], [-3.0], [-3.0]]), torch.tensor([[0.0], [1.0], [3.0]]), ), - ( - "static_mse", - {}, - torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), - torch.tensor([[0.0], [-3.0], [-3.0]]), - torch.tensor([[0.0], [1.0], [3.0]]), - ), ( "minmax", # moving average {"averaging_constant": 0.1}, @@ -105,13 +98,12 @@ def test_observer_moving_static( min_vals, max_vals = [], [] for _observed in observed: if not is_global: - observer(_observed) - min_vals.append(observer.min_vals) - max_vals.append(observer.max_vals) + _, _, _min_vals, _max_vals = observer._forward_with_minmax(_observed) else: - observer.get_global_scale(_observed) - min_vals.append(observer.global_min_vals) - max_vals.append(observer.global_max_vals) + _, _min_vals, _max_vals = observer._get_global_scale_with_minmax(_observed) + + min_vals.append(_min_vals) + max_vals.append(_max_vals) min_vals = torch.stack(min_vals) max_vals = torch.stack(max_vals) diff --git a/tests/llmcompressor/observers/test_min_max.py b/tests/llmcompressor/observers/test_min_max.py index 8edc0d8e5b..645f9a1dcd 100644 --- a/tests/llmcompressor/observers/test_min_max.py +++ b/tests/llmcompressor/observers/test_min_max.py @@ -90,9 +90,9 @@ def test_min_max_observer_value_update(): curr_max = 1 curr_min = 1 for i, tensor in enumerate(tensors): - observer(tensor) - curr_max = max(observer.max_vals[0], curr_max) - curr_min = min(observer.min_vals[0], curr_min) + _, _, min_vals, max_vals = observer._forward_with_minmax(tensor) + curr_max = max(max_vals[0], curr_max) + curr_min = min(min_vals[0], curr_min) if i < 2: assert curr_max == 1 From 79b4c335b8c5ffd4e364e18584251e216b901e67 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 15:19:58 -0400 Subject: [PATCH 12/23] update test, slightly change mse Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/base.py | 25 +++- src/llmcompressor/observers/min_max.py | 24 +++- src/llmcompressor/observers/moving_base.py | 13 ++- src/llmcompressor/observers/mse.py | 109 ++++++++++++++---- .../modifiers/calibration/test_observers.py | 16 ++- tests/llmcompressor/observers/test_mse.py | 17 +-- 6 files changed, 161 insertions(+), 43 deletions(-) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 91fd1092e6..7e52ca6be6 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -7,6 +7,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin +from compressed_tensors.utils import align_module_device from llmcompressor.observers.helpers import flatten_for_calibration @@ -18,9 +19,22 @@ class Observer(InternalModule, RegistryMixin): """ - Base Observer class to be 
subclassed for specific implementation. - Subclasses should override `calculate_qparams` to return a scale, zero_point - pair + Base class for observers which compute quantization parameters given observerations + of weights, activations, or attention states. + + Example: + ```python + module = ... + observer = Observer.load_from_registry(observer, base_name="weight", args=...) + module.global_scale = observer.get_global_scale(module.weight) + scales, zero_points = observer(module.weight) + ``` + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization """ def __init__( @@ -111,10 +125,11 @@ def _get_global_scale_with_minmax( return global_scale, global_min_vals, global_max_vals def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: - if self.module is None: + if self.module is None or (module := self.module()) is None: return None - return getattr(self.module(), f"{self.base_name}_{name}", None) + with align_module_device(module): + return getattr(module, f"{self.base_name}_{name}", None) def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): if ( diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index c4254608c1..3888ed67b2 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -9,7 +9,13 @@ @Observer.register("memoryless_minmax") class MemorylessMinMaxObserver(Observer): """ - TODO + Compute quantization parameters by taking the min/max of the observed value + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization """ def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: @@ -22,7 +28,13 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: @Observer.register("static_minmax") class StaticMinMaxObserver(Observer): """ - TODO + Compute quantization parameters by taking the min/max of all observed values + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization """ def __init__(self, *args, **kwargs): @@ -60,7 +72,13 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: @Observer.register("minmax") class MinMaxObserver(MovingAverageObserverBase): """ - TODO + Compute quantization parameters by taking the moving average of all min/max values + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. 
This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization """ def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: diff --git a/src/llmcompressor/observers/moving_base.py b/src/llmcompressor/observers/moving_base.py index d8acf0c1b0..f94c474284 100644 --- a/src/llmcompressor/observers/moving_base.py +++ b/src/llmcompressor/observers/moving_base.py @@ -11,7 +11,13 @@ class MovingAverageObserverBase(Observer): """ - TODO + Compute quantization parameters by taking the moving average of min/max values + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization """ def __init__( @@ -40,6 +46,7 @@ def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ Calculate the min and max value of the observed value (without moving average) + for the purposes of global scale calculation """ raise NotImplementedError() @@ -66,8 +73,8 @@ def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ - Calculate moving average of min and max values from observed value for the - purposes of global scale calculation + Calculate moving average of min and max values from observed value + for the purposes of global scale calculation :param observed: value being observed whose shape is (num_observations, 1, group_size) diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index 6a5e31974a..f21c675ab6 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -17,6 +17,34 @@ @Observer.register("memoryless_mse") class MemorylessMSEObserver(Observer): + """ + Compute quantization parameters by finding the optimal min/max values which minimize + the mean of quantization error squared + + ```psuedocode + mse_quant_error := mean((x - fake_quant(x))**2) + global_scale <- min[min_vals, max_vals, global_scale](mse_quant_error(x)) + scale, zp <- min[min_vals, max_vals](mse_quant_error(x, global_scale)) + ``` + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization\n + maxshrink: maximum shrink amount (in “grid steps”). The number of + search steps is int(maxshrink * grid)\n + patience: number of consecutive search steps without improvement before + early stopping\n + grid: resolution of the shrink search. Larger values give finer granularity + in shrink factors\n + norm: exponent used when computing the error. norm = 2 approximates MSE\n + global_scale: precomputed global scale to use for quantization. 
Ignored if + `optimize_global_scale` is True\n + optimize_global_scale: If True, recompute ``global_scale`` from the + candidate min/max during each step of the search + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) observer_kwargs = self.args.observer_kwargs @@ -26,6 +54,7 @@ def __init__(self, *args, **kwargs): self.norm = observer_kwargs.get("norm", 2.4) def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + # min[min_vals, max_vals](mse_quant_error) global_scale = self._get_module_param("global_scale") return _grid_search_mse( observed, @@ -39,6 +68,7 @@ def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple: ) def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + # min[min_vals, max_vals, global_scale](mse_quant_error) return _grid_search_mse( observed, self.args, @@ -54,8 +84,31 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: @Observer.register("mse") class MovingAverageMSEObserver(MovingAverageObserverBase): """ - Implements a dynamic quantization observer that sets the scale and - zero point based on a moving average of the mse-clipped min and max observed values + Compute quantization parameters by finding the optimal min/max values which minimize + the mean of quantization error squared. + + ```psuedocode + mse_quant_error := mean((x - fake_quant(x))**2) + global_scale <- min[min_vals, max_vals, global_scale](mse_quant_error(x)) + scale, zp <- min[min_vals, max_vals](mse_quant_error(x, global_scale)) + ``` + + :param base_name: str used to name the observer attribute + :param args: quantization args used to calibrate and quantize the observed value + :param module: optional module with attached quantization parameters. This argument + is required to utilize existing qparams such as global_scale or g_idx + :param **observer_kwargs: keyword arguments for observer initialization\n + maxshrink: maximum shrink amount (in “grid steps”). The number of + search steps is int(maxshrink * grid)\n + patience: number of consecutive search steps without improvement before + early stopping\n + grid: resolution of the shrink search. Larger values give finer granularity + in shrink factors\n + norm: exponent used when computing the error. norm = 2 approximates MSE\n + global_scale: precomputed global scale to use for quantization. Ignored if + `optimize_global_scale` is True\n + optimize_global_scale: If True, recompute ``global_scale`` from the + candidate min/max during each step of the search """ def __init__(self, *args, **kwargs): @@ -67,6 +120,7 @@ def __init__(self, *args, **kwargs): self.norm = observer_kwargs.get("norm", 2.4) def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + # min[min_vals, max_vals](mse_quant_error) global_scale = self._get_module_param("global_scale") return _grid_search_mse( observed, @@ -80,6 +134,7 @@ def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple: ) def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: + # min[min_vals, max_vals, global_scale](mse_quant_error) return _grid_search_mse( observed, self.args, @@ -103,18 +158,32 @@ def _grid_search_mse( optimize_global_scale: bool = False, ) -> MinMaxTuple: """ - Grid search for MSE-optimal min and max values. 
If global scale is not given, - then - - :param observed: value being observed whose shape is - (num_observations, *qparam_shape, group_size) - :return: minimum and maximum values which minimize reconstruction error + Perform a 1-D grid search to find per-channel min/max ranges that minimize + mean-squared quantization error. + + This routine progressively “shrinks” the absolute min/max ranges of the + observed tensor and evaluates the quantization error at each candidate + range. For each shrink factor ``p = 1 - i/grid`` up to ``maxshrink``. + + :param observed: value of shape (num_observations, *qparams_shape, group_size) + :param args: quantization args used for computing qparams and fake quant + :param maxshrink: maximum shrink amount (in “grid steps”). The number of + search steps is int(maxshrink * grid) + :param patience: number of consecutive search steps without improvement before + early stopping + :param grid: resolution of the shrink search. Larger values give finer granularity + in shrink factors + :param norm: exponent used when computing the error. norm = 2 approximates MSE + :param global_scale: precomputed global scale to use for quantization. Ignored if + `optimize_global_scale` is True + :param optimize_global_scale: If True, recompute ``global_scale`` from the + candidate min/max during each step of the search """ - absolute_min_val = torch.amin(observed, dim=(0, -1)) - absolute_max_val = torch.amax(observed, dim=(0, -1)) - best = torch.full_like(absolute_min_val, torch.finfo(absolute_min_val.dtype).max) - min_val = torch.ones_like(absolute_min_val) - max_val = torch.zeros_like(absolute_max_val) + min_val = torch.amin(observed, dim=(0, -1)) + max_val = torch.amax(observed, dim=(0, -1)) + best_error = torch.full_like(min_val, torch.finfo(min_val.dtype).max) + best_min_val = min_val.clone() + best_max_val = max_val.clone() # Early stopping params no_improve_count = 0 @@ -122,8 +191,8 @@ def _grid_search_mse( # @ksayers @HGCharles: investigate searching over separate shrinking factors for i in range(int(maxshrink * grid)): p = 1 - i / grid - shrinked_min_val = p * absolute_min_val - shrinked_max_val = p * absolute_max_val + shrinked_min_val = p * min_val + shrinked_max_val = p * max_val if optimize_global_scale: global_scale = generate_gparam(shrinked_min_val, shrinked_max_val) @@ -153,15 +222,15 @@ def _grid_search_mse( q.pow_(norm) err = torch.sum(q, dim=(0, -1)) - tmp = err < best + tmp = err < best_error if torch.any(tmp): - best[tmp] = err[tmp] - min_val[tmp] = shrinked_min_val[tmp] - max_val[tmp] = shrinked_max_val[tmp] + best_error[tmp] = err[tmp] + best_min_val[tmp] = shrinked_min_val[tmp] + best_max_val[tmp] = shrinked_max_val[tmp] no_improve_count = 0 else: no_improve_count += 1 if no_improve_count >= patience: break - return min_val, max_val + return best_min_val, best_max_val diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index ee44f0510d..e6bdad0be1 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -65,6 +65,13 @@ def assert_alike(a, b): @pytest.mark.parametrize( "name,kwargs,observed,exp_min_vals,exp_max_vals", ( + ( + "memoryless_minmax", + {}, + torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), + torch.tensor([[0.0], [-3.0], [-1.0]]), + torch.tensor([[0.0], [1.0], [3.0]]), + ), ( "static_minmax", {}, @@ -79,6 +86,13 @@ def assert_alike(a, b): torch.tensor([[0.0], [-0.3], 
[-0.37]]), torch.tensor([[0.0], [0.1], [0.39]]), ), + ( + "memoryless_mse", + {}, + torch.tensor([[0.0, 0.0], [-3.0, 1.0], [-1.0, 3.0]]), + torch.tensor([[0.0], [-3.0], [-1.0]]), + torch.tensor([[0.0], [1.0], [3.0]]), + ), ( "mse", # moving average {"averaging_constant": 0.1}, @@ -88,7 +102,7 @@ def assert_alike(a, b): ), ), ) -def test_observer_moving_static( +def test_observer_min_max_vals( name, kwargs, observed, exp_min_vals, exp_max_vals, is_global ): observer = Observer.load_from_registry( diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index 618840b438..b1c0b4e702 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -88,22 +88,17 @@ def test_mse_fp4(): "mse", base_name="weight", args=weights, module=module ) + # must compute global scale first + with pytest.raises(ValueError): + scale, zero_point = observer(module.weight) + + # compute qparams global_scale = observer.get_global_scale(module.weight) module.weight_global_scale = global_scale scale, zero_point = observer(module.weight) + # check mse loss qdq_tensor = fake_quantize( module.weight, scale, zero_point, weights, global_scale=global_scale ) assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.0015 # 0.0013 - - # sanity check: scales calibrated without global scales are worse - observer = Observer.load_from_registry( - "mse", base_name="weight", args=weights, module=module - ) - global_scale = observer.get_global_scale(module.weight) - scale, zero_point = observer(module.weight) # no global scale - qdq_tensor = fake_quantize( - module.weight, scale, zero_point, weights, global_scale=global_scale - ) - assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) >= 0.0035 # 0.0036 From c44f2136e3b0f4afc06b0f506fb820fd23c2b2b0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 15:21:52 -0400 Subject: [PATCH 13/23] save gptq compute Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 9464caff5e..0f21cd6345 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -100,7 +100,6 @@ def quantize_weight( base_name="weight", args=quant_args, module=module, - averaging_constant=1.0, # ignore moving average ) # standardize shape and dtype From c8a00d14f3dc7933cbdcf0f5d42c3ba8fcac6b5b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 15:36:50 -0400 Subject: [PATCH 14/23] skip gradient calculations to save memory Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 7e52ca6be6..384bbf6ead 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -74,6 +74,7 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ raise NotImplementedError() + @torch.no_grad def forward(self, observed: torch.Tensor) -> ScaleZpTuple: """ Calculate updated scales and zero points from observed value @@ -85,6 +86,7 @@ def forward(self, observed: torch.Tensor) -> ScaleZpTuple: scales, zero_points, _min, _max = self._forward_with_minmax(observed) return (scales, zero_points) + @torch.no_grad def 
get_global_scale(self, observed: torch.Tensor) -> torch.Tensor: """ Calculate updated global scale from observed value From f023f5f947fa1e7f9209ef4e8b900c7f959e97cb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 15:41:40 -0400 Subject: [PATCH 15/23] fix lifecycle tests Signed-off-by: Kyle Sayers --- .../modifiers/calibration/test_lifecycle.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/llmcompressor/modifiers/calibration/test_lifecycle.py b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py index dae4054636..0878975f67 100644 --- a/tests/llmcompressor/modifiers/calibration/test_lifecycle.py +++ b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py @@ -146,11 +146,13 @@ def test_static_weight_quantization( linear.weight_global_scale.data = global_scale # calibrate quantization parameters - scale, zero_point = linear.weight_observer(linear.weight) + scale, zero_point, min_vals, max_vals = linear.weight_observer._forward_with_minmax( + linear.weight + ) linear.weight_scale.data = scale linear.weight_zero_point.data = zero_point - assert torch.equal(linear.weight_observer.min_vals, exp_min_val) - assert torch.equal(linear.weight_observer.max_vals, exp_max_val) + assert torch.equal(min_vals, exp_min_val) + assert torch.equal(max_vals, exp_max_val) # forward pass input = torch.eye(input_size, dtype=torch.bfloat16) @@ -231,14 +233,21 @@ def test_static_activation_quantization( assert getattr(linear, "quantization_scheme") is scheme initialize_observer(linear, "input") + min_vals, max_vals = None, None + # calibrate quantization parameters def calibrate_input_hook(_, args): + nonlocal min_vals + nonlocal max_vals + if hasattr(linear, "input_global_scale"): global_scale = linear.input_observer.get_global_scale(args[0]) linear.input_global_scale.data = global_scale if linear.quantization_scheme.input_activations.dynamic is False: - scale, zero_point = linear.input_observer(args[0]) + scale, zero_point, min_vals, max_vals = ( + linear.input_observer._forward_with_minmax(args[0]) + ) linear.input_scale.data = scale linear.input_zero_point.data = zero_point @@ -249,9 +258,9 @@ def calibrate_input_hook(_, args): # check calibration if exp_min_val is not None: - assert torch.equal(linear.input_observer.min_vals, exp_min_val) + assert torch.equal(min_vals, exp_min_val) if exp_max_val is not None: - assert torch.equal(linear.input_observer.max_vals, exp_max_val) + assert torch.equal(max_vals, exp_max_val) # check forward pass assert torch.allclose(output, exp_quant.to(output.dtype)) @@ -319,7 +328,9 @@ def test_static_attention_quantization( # calibrate quantization parameters if scheme.input_activations.dynamic is False: - scale, zero_point = attention.k_observer(input) + scale, zero_point, min_vals, max_vals = ( + attention.k_observer._forward_with_minmax(input) + ) attention.k_scale.data = scale attention.k_zero_point.data = zero_point @@ -328,9 +339,9 @@ def test_static_attention_quantization( # check calibration if exp_min_val is not None: - assert torch.equal(attention.k_observer.min_vals, exp_min_val) + assert torch.equal(min_vals, exp_min_val) if exp_max_val is not None: - assert torch.equal(attention.k_observer.max_vals, exp_max_val) + assert torch.equal(max_vals, exp_max_val) # check forward pass assert torch.allclose(output, exp_quant.to(output.dtype)) From c5aeca43afd12da6ed506b2d66699663b7ecc05d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 9 Oct 2025 12:37:34 -0400 Subject: [PATCH 16/23] squash 
Signed-off-by: Kyle Sayers --- setup.py | 2 +- .../modifiers/quantization/__init__.py | 1 - .../modifiers/quantization/calibration.py | 81 ++++--------------- .../quantization/quantization/mixin.py | 59 +++++++------- src/llmcompressor/modifiers/utils/hooks.py | 15 +++- src/llmcompressor/observers/helpers.py | 9 ++- .../transformers/kv_cache/test_kv_cache.py | 22 ++--- 7 files changed, 77 insertions(+), 112 deletions(-) diff --git a/setup.py b/setup.py index 3b68083a61..14632e6afe 100644 --- a/setup.py +++ b/setup.py @@ -160,7 +160,7 @@ def localversion_func(version: ScmVersion) -> str: "torchvision", "librosa==0.11.0", "soundfile", - "torchcodec", + #"torchcodec", # linting, formatting, and type checking "mypy~=1.10.0", "ruff~=0.4.8", diff --git a/src/llmcompressor/modifiers/quantization/__init__.py b/src/llmcompressor/modifiers/quantization/__init__.py index f6ad149fbb..1ca6912221 100644 --- a/src/llmcompressor/modifiers/quantization/__init__.py +++ b/src/llmcompressor/modifiers/quantization/__init__.py @@ -1,5 +1,4 @@ # ruff: noqa -from .cache import * from .gptq import * from .quantization import * diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 5540532c97..bc78596d3f 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -1,22 +1,17 @@ -import inspect -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import torch from compressed_tensors.quantization import ( DynamicType, - KVCacheScaleType, QuantizationArgs, - QuantizationScheme, QuantizationStatus, QuantizationStrategy, ) from compressed_tensors.quantization.lifecycle.forward import forward_quantize -from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger from torch.nn import Module -from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain @@ -25,13 +20,13 @@ "update_weight_zp_scale", "calibrate_input_hook", "calibrate_output_hook", - "calibrate_kv_cache_input_hook", - "calibrate_kv_cache_output_hook", - "initialize_quantized_kv_cache", "freeze_module_quantization", "apply_calibration_status", "reset_quantization_status", "update_weight_global_scale", + "calibrate_query_hook", + "calibrate_key_hook", + "calibrate_value_hook", ] @@ -151,8 +146,9 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): if value.numel() == 0: return - quantization_scheme = getattr(module, "quantization_scheme", None) - quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None) + field_name = "input" if base_name != "output" else "output" # input,q,k,v,output + args_attr = f"quantization_scheme.{field_name}_activations" + quantization_args = getattr_chain(module, args_attr, None) calculate_qparams = True calculate_gparam = False @@ -202,60 +198,16 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): return output -def calibrate_kv_cache_input_hook( - module: Module, args: Any, kwargs: Dict[str, Any] -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - """ - Hook to update inputs to attention layers when running - kv_cache quantization. Will update the passed in - kv_cache to singleton QuantizedKVParameterCache. 
- """ - kv_cache = getattr(module, "kv_cache") - if not hasattr(module, "_past_kv_name"): - # Determine which past KV parameter name to use once and cache it - # TODO: Find a better place to cache this - module._past_kv_name = ( - "past_key_value" # transformers#39956 - if "past_key_value" in inspect.signature(module.forward).parameters - else "past_key_values" - ) - - kwargs[module._past_kv_name] = kv_cache - kwargs["use_cache"] = False - return args, kwargs +def calibrate_query_hook(module: Module, query_states: torch.Tensor): + calibrate_activations(module, query_states, base_name="q") -def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor): - """ - Hook to update k_scale and v_scale parameters when running kv_cache quantization. - """ - kv_cache = getattr(module, "kv_cache") - k_scale = kv_cache.k_scales[module.layer_idx] - v_scale = kv_cache.v_scales[module.layer_idx] - update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale) - update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale) +def calibrate_key_hook(module: Module, key_states: torch.Tensor): + calibrate_activations(module, key_states, base_name="k") -def initialize_quantized_kv_cache(module: Module): - """ - Initialize a quantized kv_cache on a module (analogous to initializing an observer) - When a config specifying kv_cache quantization is applied to a model, the kv_cache - args are redefined as the output_activations targeting attention modules. - - This function should be called on attention modules with output_activations - """ - scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None) - existing_kv_cache = getattr(module, "kv_cache", None) - - if ( - scheme is None - or not is_kv_cache_quant_scheme(scheme) - or isinstance(existing_kv_cache, QuantizedKVParameterCache) - ): - return - - quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations) - setattr(module, "kv_cache", quantized_kv_cache) +def calibrate_value_hook(module: Module, value_states: torch.Tensor): + calibrate_activations(module, value_states, base_name="v") def apply_calibration_status(module: Module): @@ -284,16 +236,11 @@ def freeze_module_quantization(module: Module): return # remove observers - for name in ("input", "weight", "output"): + for name in ("input", "weight", "output", "q", "k", "v"): obs_name = f"{name}_observer" if hasattr(module, obs_name): delattr(module, obs_name) - # remove quantized kv_cache - kv_cache = getattr(module, "kv_cache", None) - if isinstance(kv_cache, QuantizedKVParameterCache): - delattr(module, "kv_cache") - module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index f37efb56a7..706c7c0744 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -1,6 +1,13 @@ from typing import Any, Dict, List, Optional, Set, Union import torch +from compressed_tensors.modeling import ( + IMPL_ATTR, + KV_CACHE_ATTR, + register_key_hook, + register_query_hook, + register_value_hook, +) from compressed_tensors.quantization import ( DynamicType, QuantizationArgs, @@ -21,12 +28,12 @@ from llmcompressor.modifiers.quantization.calibration import ( apply_calibration_status, calibrate_input_hook, - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, + calibrate_key_hook, calibrate_output_hook, + 
calibrate_query_hook, + calibrate_value_hook, freeze_module_quantization, initialize_observer, - initialize_quantized_kv_cache, reset_quantization_status, ) from llmcompressor.modifiers.utils.hooks import HooksMixin @@ -253,19 +260,21 @@ def _initialize_observers(self, module: torch.nn.Module): # input activations if input: - initialize_observer(module, base_name="input") + if not is_attention: + initialize_observer(module, base_name="input") + else: + if hasattr(module, IMPL_ATTR): + initialize_observer(module, base_name="q") + if hasattr(module, KV_CACHE_ATTR): + initialize_observer(module, base_name="k") + initialize_observer(module, base_name="v") # weight observers (used by `update_weight_zp_scale` or child modifier) if weight: initialize_observer(module, base_name="weight") - # kv_cache activations. Within `apply_quantization_config`, the config is - # modified to use attention output quantization if a kv_cache_scheme exists - if is_attention and output: - initialize_quantized_kv_cache(module) - # output activations - elif output: + if output: initialize_observer(module, base_name="output") def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: @@ -284,29 +293,19 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: # input activations if input: - hooks.add( - self.register_hook(module, calibrate_input_hook, "forward_pre") - ) - - # kv_cache activations. Within `apply_quantization_config`, the config is - # modified to use attention output quantization if a kv_cache_scheme exists - if is_attention and output: - hooks.add( - self.register_hook( - module, - calibrate_kv_cache_input_hook, - "forward_pre", - with_kwargs=True, + if not is_attention: + hooks.add( + self.register_hook(module, calibrate_input_hook, "forward_pre") ) - ) - hooks.add( - self.register_hook( - module, calibrate_kv_cache_output_hook, "forward" - ) - ) + else: + if hasattr(module, IMPL_ATTR): + hooks.add(register_query_hook(module, calibrate_query_hook)) + if hasattr(module, KV_CACHE_ATTR): + hooks.add(register_key_hook(module, calibrate_key_hook)) + hooks.add(register_value_hook(module, calibrate_value_hook)) # output activations - elif output: + if output: hooks.add(self.register_hook(module, calibrate_output_hook, "forward")) return hooks diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 98d5240e21..f3c6164918 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,6 +1,7 @@ import contextlib +from copy import deepcopy from functools import wraps -from typing import Any, Callable, ClassVar, Optional, Set, Union +from typing import Any, Callable, ClassVar, Dict, Optional, Set, Union import torch from loguru import logger @@ -39,6 +40,7 @@ class HooksMixin(BaseModel): # attached to global HooksMixin class _HOOKS_DISABLED: ClassVar[bool] = False _HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set() + _HOOKS_TO_MODIFIER: ClassVar[Dict[RemovableHandle, "HooksMixin"]] = dict() # attached to local subclasses _hooks: Set[RemovableHandle] = set() @@ -95,6 +97,7 @@ def wrapped_hook(*args, **kwargs): register_function = getattr(target, f"register_{hook_type}_hook") handle = register_function(wrapped_hook, **kwargs) self._hooks.add(handle) + self._HOOKS_TO_MODIFIER[handle] = self logger.debug(f"{self} added {handle}") return handle @@ -113,3 +116,13 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None): hook.remove() self._hooks -= handles + 
for handle in handles: + self._HOOKS_TO_MODIFIER.pop(handle, None) + + @classmethod + def remove_hooks_by_id(cls, ids: Set[int]): + handles = deepcopy(cls._HOOKS_TO_MODIFIER) + for handle in handles: + if handle.id in ids: + modifier = cls._HOOKS_TO_MODIFIER[handle] + modifier.remove_hooks(set(handle)) diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index 4560da1b85..02f8e32bda 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -52,6 +52,8 @@ def flatten_for_calibration( def _flatten_weight( value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None ): + # value.shape = (num_rows, num_cols) + if args.strategy == QuantizationStrategy.TENSOR: # (1, 1, num_weight_elems) return value.reshape((1, 1, -1)) @@ -87,6 +89,8 @@ def _flatten_weight( def _flatten_activation(value: torch.Tensor, args: QuantizationArgs): + # value.shape = (batch_size, seq_len, hidden_dim) + if args.strategy == QuantizationStrategy.TENSOR: # (batch_size * seq_len, 1, hidden_dim) return value.reshape((-1, 1, value.size(-1))) @@ -111,10 +115,11 @@ def _flatten_activation(value: torch.Tensor, args: QuantizationArgs): def _flatten_attention(value: torch.Tensor, args: QuantizationArgs): + # value.shape = (batch_size, num_heads, seq_len, head_dim) + if args.strategy == QuantizationStrategy.TENSOR: - # (batch_size, seq_len, num_heads, head_dim) # (batch_size * seq_len, 1, num_heads * head_dim) - return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) + return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2) if args.strategy == QuantizationStrategy.TOKEN: raise ValueError("Token quantization cannot be applied to attention") diff --git a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py index 80886e1dd8..58f139f342 100644 --- a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py +++ b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py @@ -3,7 +3,7 @@ import pytest from accelerate import init_empty_weights -from compressed_tensors.quantization import KVCacheScaleType, is_attention_module +from compressed_tensors.quantization import is_attention_module from datasets import load_dataset from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers.utils.quantization_config import CompressedTensorsConfig @@ -14,7 +14,7 @@ NUM_CALIBRATION_SAMPLES = 16 MAX_SEQUENCE_LENGTH = 512 DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" +DATASET_SPLIT = f"train_sft[:{NUM_CALIBRATION_SAMPLES}]" MODEL_IDS = [ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -49,9 +49,11 @@ def _oneshot_fixture(tmp_path: Path): symmetric=symmetric, ) oneshot_args = dict( - dataset="open_platypus", recipe=recipe, - num_calibration_samples=16, + dataset="open_platypus", + splits={"calibration": f"train[:{NUM_CALIBRATION_SAMPLES}]"}, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + max_seq_length=MAX_SEQUENCE_LENGTH, ) for model_id in MODEL_IDS: oneshot_args["output_dir"] = os.path.join(tmp_path, model_id) @@ -161,8 +163,8 @@ def test_kv_cache_model_state_dict_attr(oneshot_fixture, tmp_path): for name, submodule in model.named_modules(): if is_attention_module(submodule): counts += 1 - assert hasattr(submodule, KVCacheScaleType.VALUE.value) - assert hasattr(submodule, KVCacheScaleType.KEY.value) + assert hasattr(submodule, "v_scale") + assert hasattr(submodule, "k_scale") assert counts > 0 @@ -200,8 +202,8 
@@ def test_kv_cache_gptq_config_format(kv_cache_fixture, tmp_path): for name, submodule in model.named_modules(): if is_attention_module(submodule): counts += 1 - assert hasattr(submodule, KVCacheScaleType.VALUE.value) - assert hasattr(submodule, KVCacheScaleType.KEY.value) + assert hasattr(submodule, "v_scale") + assert hasattr(submodule, "k_scale") assert counts > 0 @@ -240,7 +242,7 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path): for name, submodule in model.named_modules(): if is_attention_module(submodule): counts += 1 - assert hasattr(submodule, KVCacheScaleType.VALUE.value) - assert hasattr(submodule, KVCacheScaleType.KEY.value) + assert hasattr(submodule, "v_scale") + assert hasattr(submodule, "k_scale") assert counts > 0 From 2c1e55a0bc1cd57fb77483407e8abd019c6f0fe0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 9 Oct 2025 12:40:50 -0400 Subject: [PATCH 17/23] reduce diff Signed-off-by: Kyle Sayers --- setup.py | 2 +- src/llmcompressor/modifiers/utils/hooks.py | 15 +-------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/setup.py b/setup.py index 14632e6afe..3b68083a61 100644 --- a/setup.py +++ b/setup.py @@ -160,7 +160,7 @@ def localversion_func(version: ScmVersion) -> str: "torchvision", "librosa==0.11.0", "soundfile", - #"torchcodec", + "torchcodec", # linting, formatting, and type checking "mypy~=1.10.0", "ruff~=0.4.8", diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index f3c6164918..98d5240e21 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,7 +1,6 @@ import contextlib -from copy import deepcopy from functools import wraps -from typing import Any, Callable, ClassVar, Dict, Optional, Set, Union +from typing import Any, Callable, ClassVar, Optional, Set, Union import torch from loguru import logger @@ -40,7 +39,6 @@ class HooksMixin(BaseModel): # attached to global HooksMixin class _HOOKS_DISABLED: ClassVar[bool] = False _HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set() - _HOOKS_TO_MODIFIER: ClassVar[Dict[RemovableHandle, "HooksMixin"]] = dict() # attached to local subclasses _hooks: Set[RemovableHandle] = set() @@ -97,7 +95,6 @@ def wrapped_hook(*args, **kwargs): register_function = getattr(target, f"register_{hook_type}_hook") handle = register_function(wrapped_hook, **kwargs) self._hooks.add(handle) - self._HOOKS_TO_MODIFIER[handle] = self logger.debug(f"{self} added {handle}") return handle @@ -116,13 +113,3 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None): hook.remove() self._hooks -= handles - for handle in handles: - self._HOOKS_TO_MODIFIER.pop(handle, None) - - @classmethod - def remove_hooks_by_id(cls, ids: Set[int]): - handles = deepcopy(cls._HOOKS_TO_MODIFIER) - for handle in handles: - if handle.id in ids: - modifier = cls._HOOKS_TO_MODIFIER[handle] - modifier.remove_hooks(set(handle)) From 0b6624f5f0398103ec2a64d150bbd6ce51281943 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 9 Oct 2025 14:29:22 -0400 Subject: [PATCH 18/23] remove irrelevant tests Signed-off-by: Kyle Sayers --- .../modifiers/calibration/test_cache.py | 118 ------------------ .../modifiers/calibration/test_kv_cache.py | 94 -------------- 2 files changed, 212 deletions(-) delete mode 100644 tests/llmcompressor/modifiers/calibration/test_cache.py delete mode 100644 tests/llmcompressor/modifiers/calibration/test_kv_cache.py diff --git 
a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py deleted file mode 100644 index 70f0e61259..0000000000 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs - -from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache -from llmcompressor.observers import Observer - - -def test_is_quantized_cache_singleton(): - """ - Check if quantized_cache is a singleton, used for - passing in QuantizedKVParameterCache to the forward call of - the model's self_attn - """ - - args = QuantizationArgs() - cache = QuantizedKVParameterCache(args) - observer = args.observer - observer = Observer.load_from_registry(observer, base_name="k", args=args) - - tensor = torch.tensor([1, 2, 3]) - cache.k_scales.append(tensor) - cache.k_observers.append(observer) - - same_cache = QuantizedKVParameterCache(args) - - assert len(cache.k_scales) == len(same_cache.k_scales) - assert torch.equal(cache.k_scales[0], same_cache.k_scales[0]) - - assert cache.k_observers == same_cache.k_observers - assert hex(id(cache.k_observers[0])) == hex(id(same_cache.k_observers[0])) - - cache.reset() - - -def test_update(): - num_bits = 8 - args = QuantizationArgs(num_bits=num_bits, symmetric=True) - cache = QuantizedKVParameterCache(args) - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - denom = (2 ** (num_bits) - 1) / 2 - expected_k_scale = torch.tensor([max_key_states_val / denom]) - expected_v_scale = torch.tensor([max_value_states_val / denom]) - - assert cache.k_scales[0] == expected_k_scale - assert cache.v_scales[0] == expected_v_scale - - # new attn layer - layer_idx = 1 - cache.update(key_states, value_states, layer_idx) - - assert len(cache.k_scales) == 2 - assert len(cache.v_scales) == 2 - - assert len(cache.k_observers) == 2 - assert len(cache.v_observers) == 2 - - cache.reset() - - -def test_cache_reset(): - num_bits = 8 - args = QuantizationArgs(num_bits=num_bits, symmetric=True) - cache = QuantizedKVParameterCache(args) - - max_key_states_val = 1.0 - max_value_states_val = 2.0 - key_states = torch.cat( - (max_key_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - value_states = torch.cat( - (max_value_states_val * torch.ones(1, 2, 2), torch.ones(1, 2, 2)), dim=0 - ) - layer_idx = 0 - - cache.update(key_states, value_states, layer_idx) - assert len(cache.k_scales) == 1 - assert len(cache.v_scales) == 1 - - assert len(cache.k_observers) == 1 - assert len(cache.v_observers) == 1 - - cache.reset() - 
- # new instance, different memory addr - different_cache = QuantizedKVParameterCache(args) - - assert len(different_cache.k_scales) == 0 - assert len(different_cache.v_scales) == 0 - - assert len(different_cache.k_observers) == 0 - assert len(different_cache.v_observers) == 0 - - assert hex(id(cache)) != hex(id(different_cache)) diff --git a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py b/tests/llmcompressor/modifiers/calibration/test_kv_cache.py deleted file mode 100644 index b22e7ec401..0000000000 --- a/tests/llmcompressor/modifiers/calibration/test_kv_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -import torch -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, - is_attention_module, -) -from transformers import AutoModelForCausalLM - -from llmcompressor.modifiers.quantization.calibration import ( - calibrate_kv_cache_input_hook, - calibrate_kv_cache_output_hook, - freeze_module_quantization, - initialize_quantized_kv_cache, -) - -config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "kv_cache_scheme": { - "num_bits": 8, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "config_groups": { - "group_1": { - "weights": { - "num_bits": 4, - "type": "int", - "symmetric": True, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, -} - - -def _prep_for_calibration(module: torch.nn.Module): - if is_attention_module(module): - module.register_forward_pre_hook( - calibrate_kv_cache_input_hook, with_kwargs=True - ) - module.register_forward_hook(calibrate_kv_cache_output_hook) - module.quantization_status = QuantizationStatus.CALIBRATION - - -@pytest.mark.parametrize("config", [config]) -def test_kv_cache_quantization(config): - sample = { - name: torch.ones((1, 32)).long() - for name in ["input_ids", "attention_mask", "labels"] - } - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - torch_dtype="auto", - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - model.apply(initialize_quantized_kv_cache) - model.apply(_prep_for_calibration) - - with torch.no_grad(): - _ = model(**sample) - - model.apply(freeze_module_quantization) - - reloaded_config = QuantizationConfig.from_pretrained(model) - - assert ( - config.kv_cache_scheme.model_dump().keys() - == reloaded_config.kv_cache_scheme.model_dump().keys() - ) - assert list(config.kv_cache_scheme.model_dump().values()) == list( - reloaded_config.kv_cache_scheme.model_dump().values() - ) From 241c31059aa84a1e03e72e70243be4cb3f4164bd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 16:48:29 -0400 Subject: [PATCH 19/23] remove kv cache Signed-off-by: Kyle Sayers --- .../modifiers/quantization/cache.py | 218 ------------------ 
1 file changed, 218 deletions(-) delete mode 100644 src/llmcompressor/modifiers/quantization/cache.py diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py deleted file mode 100644 index 53eca8d075..0000000000 --- a/src/llmcompressor/modifiers/quantization/cache.py +++ /dev/null @@ -1,218 +0,0 @@ -""" -Quantized key-value cache implementation for efficient inference. - -Provides quantized KV cache classes extending HuggingFace's -DynamicCache with quantization support. Enables memory-efficient attention -mechanisms by quantizing cached key and value tensors during model -inference with configurable quantization strategies. -""" - -from typing import Any, Dict, List, Optional, Tuple - -from compressed_tensors.quantization import KVCacheScaleType, QuantizationArgs -from torch import Tensor -from transformers import DynamicCache - -from llmcompressor.observers import Observer - - -class QuantizedKVParameterCache(DynamicCache): - """ - Quantized KV cache used in the forward call based on HF's dynamic cache. - Quantization strategy (tensor, group, channel) set from Quantization arg's strategy - Singleton, so that the same cache gets reused in all forward call of self_attn. - Each time forward is called, .update() is called, and ._quantize(), ._dequantize() - gets called appropriately. - The size of tensor is - `[batch_size, num_heads, seq_len - residual_length, head_dim]`. - - - Triggered by adding kv_cache_scheme in the recipe. - - Example: - - ```python3 - recipe = ''' - quant_stage: - quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: 8 - type: float - strategy: tensor - dynamic: false - symmetric: true - ''' - - """ - - _instance = None - _initialized = False - - def __new__(cls, *args, **kwargs): - """Singleton""" - if cls._instance is None: - cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) - return cls._instance - - def __init__(self, quantization_args: QuantizationArgs): - if not self._initialized: - super().__init__() - - self.quantization_args = quantization_args - - self.k_observers: List[Observer] = [] - self.v_observers: List[Observer] = [] - - # each index corresponds to layer_idx of the attention layer - self.k_scales: List[Tensor] = [] - self.v_scales: List[Tensor] = [] - - self.k_zps: List[Tensor] = [] - self.v_zps: List[Tensor] = [] - - self._initialized = True - - def update( - self, - key_states: Tensor, - value_states: Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Get the k_scale and v_scale and output the - fakequant-ed key_states and value_states - """ - - if len(self.k_observers) <= layer_idx: - k_observer = Observer.load_from_registry( - self.quantization_args.observer, - base_name="k", - args=self.quantization_args, - ) - v_observer = Observer.load_from_registry( - self.quantization_args.observer, - base_name="v", - args=self.quantization_args, - ) - - # NOTE: User may ignore some layers in configuration, - # meaning len(self.k_observers) <= layer_idx-1 - # Must account for that case by padding list so that - # index of lists corresponds to layer_idx - _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer) - _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer) - - q_key_states = self._quantize( - key_states.contiguous(), KVCacheScaleType.KEY, layer_idx - ) - q_value_states = self._quantize( - value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx - ) - - qdq_key_states = 
self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx) - qdq_value_states = self._dequantize( - q_value_states, KVCacheScaleType.VALUE, layer_idx - ) - - keys_to_return, values_to_return = qdq_key_states, qdq_value_states - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """ - Returns the sequence length of the cached states. - A layer index can be optionally passed. - """ - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and - # rely on `_seen_tokens` which is updated every "layer_idx" == 0, - # this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to - # verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def reset_states(self): - """reset the kv states (used in calibration)""" - self.key_cache: List[Tensor] = [] - self.value_cache: List[Tensor] = [] - # Used in `generate` to keep tally of how many tokens the cache has seen - self._seen_tokens = 0 - self._quantized_key_cache: List[Tensor] = [] - self._quantized_value_cache: List[Tensor] = [] - - def reset(self): - """ - Reset the instantiation, create new instance on init - """ - QuantizedKVParameterCache._instance = None - QuantizedKVParameterCache._initialized = False - - def _quantize(self, tensor, kv_type, layer_idx): - """Quantizes a key/value using a defined quantization method.""" - from compressed_tensors.quantization.lifecycle.forward import quantize - - if kv_type == KVCacheScaleType.KEY: # key type - observer = self.k_observers[layer_idx] - scales = self.k_scales - zps = self.k_zps - else: - assert kv_type == KVCacheScaleType.VALUE - observer = self.v_observers[layer_idx] - scales = self.v_scales - zps = self.v_zps - - scale, zp = observer(tensor) - _pad_and_append_at_idx_(scales, layer_idx, scale) - _pad_and_append_at_idx_(zps, layer_idx, zp) - - q_tensor = quantize( - x=tensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return q_tensor - - def _dequantize(self, qtensor, kv_type, layer_idx): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - from compressed_tensors.quantization.lifecycle.forward import dequantize - - if kv_type == KVCacheScaleType.KEY: - scale = self.k_scales[layer_idx] - zp = self.k_zps[layer_idx] - else: - assert kv_type == KVCacheScaleType.VALUE - scale = self.v_scales[layer_idx] - zp = self.v_zps[layer_idx] - - qdq_tensor = dequantize( - x_q=qtensor, - scale=scale, - zero_point=zp, - args=self.quantization_args, - ) - return qdq_tensor - - -# NOTE: Using _ suffix to denote l is modified in place -def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: - """ - Append value val to list lst at index idx, right padding if necessary - Needed because user may ignore some layers in configuration, meaning - len(lst) <= idx-1 - - >>> _pad_and_append_at_idx_([0,1,2], 5, 5) - [0, 1, 2, None, None, 5] - >>> _pad_and_append_at_idx_([0,1,2], 3, 8) - [0, 1, 2, 8] - >>> _pad_and_append_at_idx_([0,1,2], 1, 5) - [0, 5, 2] - """ - num_to_pad = idx - len(lst) + 1 - if num_to_pad > 0: - lst += [None] * num_to_pad - lst[idx] = val - return lst From 57bee2752d1ff223584cd1a90cb616149618ea7c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 17:30:37 -0400 Subject: [PATCH 20/23] support attn_head Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/helpers.py | 10 ++++++++++ 1 file 
changed, 10 insertions(+) diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index 02f8e32bda..71fa75d89d 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -85,6 +85,9 @@ def _flatten_weight( .unsqueeze(0) ) + if args.strategy == QuantizationStrategy.ATTN_HEAD: + raise ValueError("Attention head quantization cannot be applied to weights") + assert False, f"Unknown strategy {args.strategy}" @@ -111,6 +114,9 @@ def _flatten_activation(value: torch.Tensor, args: QuantizationArgs): if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to activations") + if args.strategy == QuantizationStrategy.ATTN_HEAD: + raise ValueError("Attention head quantization cannot be applied to activations") + assert False, f"Unknown strategy {args.strategy}" @@ -133,4 +139,8 @@ def _flatten_attention(value: torch.Tensor, args: QuantizationArgs): if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to attention") + if args.strategy == QuantizationStrategy.ATTN_HEAD: + # (batch_size * seq_len, num_heads, 1, 1, head_dim) + return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2) + assert False, f"Unknown strategy {args.strategy}" From b2bb9fe7631cd2b41038a49ab8677c4bbfb81ad3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 13 Oct 2025 23:47:47 -0400 Subject: [PATCH 21/23] fp4 Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index 71fa75d89d..a8aa8bddea 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -134,7 +134,8 @@ def _flatten_attention(value: torch.Tensor, args: QuantizationArgs): raise ValueError("Channel quantization cannot be applied to attention") if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - raise ValueError("Group quantization cannot be applied to attention") + # batch_size * num_heads * seq_len, num_groups, group_size) + return value.flatten(0, 2).unflatten(-1, (-1, args.group_size)) if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to attention") From 649746a4e2b429ca14f02428bf7121176e73599a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 14 Oct 2025 00:06:31 -0400 Subject: [PATCH 22/23] add tests Signed-off-by: Kyle Sayers --- .../modifiers/calibration/test_lifecycle.py | 71 +++++++++++++++---- 1 file changed, 58 insertions(+), 13 deletions(-) diff --git a/tests/llmcompressor/modifiers/calibration/test_lifecycle.py b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py index 0878975f67..087c54e696 100644 --- a/tests/llmcompressor/modifiers/calibration/test_lifecycle.py +++ b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py @@ -283,21 +283,57 @@ class MockAttention(torch.nn.Module): strategy="tensor", ), torch.tensor([0.0]), - torch.tensor([11.0]), + torch.tensor([23.0]), torch.tensor( [ [ - [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]], - [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]], + [ + [0.0000, 0.0000, 3.0625, 3.0625], + [3.0625, 6.1250, 6.1250, 6.1250], + [9.1875, 9.1875, 9.1875, 12.2500], + ], + [ + [12.2500, 12.2500, 15.3125, 15.3125], + [15.3125, 18.3750, 18.3750, 18.3750], + [21.5000, 21.5000, 21.5000, 21.5000], + ], ] ] ), - 0.19, + 0.81, ), 
# static token is not supported # channel is not supported # group is not supported - # tensor group is not supported + ( + QuantizationArgs( + num_bits=4, + type="float", # must be fp4 + symmetric=True, + strategy="tensor_group", + dynamic="local", + group_size=2, + ), + torch.tensor([0.0]), + torch.tensor([23.0]), + torch.tensor( + [ + [ + [ + [0.0000, 1.0234, 2.0469, 3.0781], + [3.2812, 4.9375, 4.9375, 7.3750], + [9.0000, 9.0000, 10.6875, 10.6875], + ], + [ + [13.1250, 13.1250, 14.7500, 14.7500], + [16.3750, 16.3750, 19.7500, 19.7500], + [21.3750, 21.3750, 23.0000, 23.0000], + ], + ] + ] + ), + 0.55, + ), # block is not supported ], ) @@ -305,28 +341,37 @@ def test_static_attention_quantization( args, exp_min_val, exp_max_val, exp_quant, exp_loss ): """ - input = tensor([[[[ 0., 1., 2.], - [ 3., 4., 5.]], - [[ 6., 7., 8.], - [ 9., 10., 11.]]]]) + input = tensor([[[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], + + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]]]]) """ # set up activation (and identity weight) - batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3 + batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4 input = torch.arange( - (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16 - ).reshape((batch_size, seq_len, num_heads, head_dim)) + (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16 + ).reshape((batch_size, num_heads, seq_len, head_dim)) attention = MockAttention() # initialize quantization parameters scheme = QuantizationScheme(targets=[], input_activations=args) initialize_qparams( - attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16 + attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16 ) attention.quantization_scheme = scheme attention.quantization_status = QuantizationStatus.INITIALIZED initialize_observer(attention, "k") # calibrate quantization parameters + if hasattr(attention, "k_global_scale"): + global_scale, min_vals, max_vals = ( + attention.k_observer._get_global_scale_with_minmax(input) + ) + attention.k_global_scale.data = global_scale + if scheme.input_activations.dynamic is False: scale, zero_point, min_vals, max_vals = ( attention.k_observer._forward_with_minmax(input) From 565f0312c9c274adfe9127c0ce51362de1361b86 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 14 Oct 2025 00:14:48 -0400 Subject: [PATCH 23/23] fix typo Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index a8aa8bddea..420fe09c67 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -134,7 +134,7 @@ def _flatten_attention(value: torch.Tensor, args: QuantizationArgs): raise ValueError("Channel quantization cannot be applied to attention") if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - # batch_size * num_heads * seq_len, num_groups, group_size) + # (batch_size * num_heads * seq_len, num_groups, group_size) return value.flatten(0, 2).unflatten(-1, (-1, args.group_size)) if args.strategy == QuantizationStrategy.BLOCK: