From 7a3d43ad0b3ff7fa3e06e680894faa740c337374 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 09:20:57 -0400 Subject: [PATCH 1/9] add tests Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/min_max.py | 27 +- .../modifiers/calibration/test_observers.py | 324 +++++++++++++++++- 2 files changed, 343 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index ce5c0e7790..b806b6d7b6 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -1,9 +1,9 @@ -from typing import Any, Optional, Tuple +from typing import Any, Iterable, Optional, Tuple, Union import torch from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam -from compressed_tensors.utils import deprecated +from compressed_tensors.utils import deprecated, patch_attr from llmcompressor.observers.base import Observer @@ -58,6 +58,8 @@ def calculate_updated_min_max( # early stopping, save some computation and memory if self.averaging_constant == 1.0: + self.min_val[tensor_id] = min_val + self.max_val[tensor_id] = max_val return min_val, max_val running_min_val = self.min_val.get(tensor_id, None) @@ -86,9 +88,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor: :return: updated global scale derived from the observed tensor """ - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed - ) + # patch to avoid affecting running means + with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}): + updated_min_val, updated_max_val = self.calculate_updated_min_max( + observed=observed + ) return generate_gparam( updated_min_val=updated_min_val, updated_max_val=updated_max_val ) @@ -126,14 +130,23 @@ def calculate_qparams( def get_qparams_along_dim( self, observed: torch.Tensor, - dim: int, + dim: Union[int, Iterable[int]], tensor_id: Optional[Any] = None, global_scale: Optional[torch.Tensor] = None, ): """ Calculate quantization parameters along the specified dimension """ - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) + # cast to set + if isinstance(dim, int): + dim = [dim] + dim = set(dim) + + # convert negative dims + dim = [d if d >= 0 else observed.ndim + d for d in dim] + + # reduce all dimensions except the the one passed as argument to this function + reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) return self.calculate_qparams( observed, reduce_dims=reduce_dims, diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index a742a48b21..c7748cd1cc 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -6,7 +6,12 @@ initialize_module_for_quantization, ) -from llmcompressor.modifiers.quantization.calibration import initialize_observer +from llmcompressor.modifiers.quantization.calibration import ( + calibrate_input_hook, + initialize_observer, + update_weight_global_scale, + update_weight_zp_scale, +) @pytest.mark.parametrize( @@ -59,3 +64,320 @@ def test_observers_update(shape, group_size, actorder): def assert_alike(a, b): assert a.dtype == b.dtype assert a.shape == b.shape + + +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + 
symmetric=True, + strategy="tensor", # equivalent to token + observer="minmax", + ), + {"default": torch.tensor(0.0)}, + {"default": torch.tensor(23.0)}, + torch.tensor( + [ + [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250], + [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500], + [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.85, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="channel", + observer="minmax", + ), + {"default": torch.tensor([[0], [6], [12], [18]])}, + {"default": torch.tensor([[5], [11], [17], [23]])}, + torch.tensor( + [ + [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875], + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], + [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.45, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="group", + group_size=3, + observer="minmax", + ), + { + "default": torch.tensor([[0], [6], [12], [18]]), + 1: torch.tensor([[3], [9], [15], [21]]), + }, + { + "default": torch.tensor([[2], [8], [14], [20]]), + 1: torch.tensor([[5], [11], [17], [23]]), + }, + torch.tensor( + [ + [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875], + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], + [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750], + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], + ], + ), + 0.45, + ), + ( + QuantizationArgs( + num_bits=4, + type="float", # tensor group requires FP4 + symmetric=True, + strategy="tensor_group", # requires float4 + group_size=3, + observer="minmax", + ), + { + "default": torch.tensor([[0], [6], [12], [18]]), + 1: torch.tensor([[3], [9], [15], [21]]), + }, + { + "default": torch.tensor([[2], [8], [14], [20]]), + 1: torch.tensor([[5], [11], [17], [23]]), + }, + torch.tensor( + [ + [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375], + [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875], + [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750], + [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000], + ], + ), + 1.1, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="block", + block_structure=[2, 3], + observer="minmax", + ), + { + "block_0_0": torch.tensor([[0]]), + "block_0_1": torch.tensor([[3]]), + "block_1_0": torch.tensor([[12]]), + "block_1_1": torch.tensor([[15]]), + }, + { + "block_0_0": torch.tensor([[8]]), + "block_0_1": torch.tensor([[11]]), + "block_1_0": torch.tensor([[20]]), + "block_1_1": torch.tensor([[23]]), + }, + torch.tensor( + [ + [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062], + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], + [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750], + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], + ], + ), + 0.5, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="token", # equivalent to tensor + observer="minmax", + ), + {"default": torch.tensor(0.0)}, + {"default": torch.tensor(23.0)}, + torch.tensor( + [ + [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250], + [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500], + [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.85, + ), + ], +) +def test_static_weight_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + weight = 
tensor([[ 0, 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17], + [18, 19, 20, 21, 22, 23]]) + """ + # set up weight + input_size, output_size = 6, 4 + linear = torch.nn.Linear(input_size, output_size, bias=False) + linear.weight.data = torch.arange( + input_size * output_size, dtype=torch.bfloat16 + ).reshape(output_size, input_size) + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], weights=args) + initialize_module_for_quantization(linear, scheme) + assert getattr(linear, "quantization_scheme") is scheme + + # calibrate quantization parameters + initialize_observer(linear, "weight") + update_weight_global_scale(linear) + update_weight_zp_scale(linear) + + observer = getattr(linear, "weight_observer") + assert ( + observer.min_val.keys() + == observer.max_val.keys() + == exp_min_val.keys() + == exp_max_val.keys() + ) + for key in observer.min_val.keys(): + assert torch.equal(observer.min_val[key], exp_min_val[key]) + assert torch.equal(observer.max_val[key], exp_max_val[key]) + + # forward pass + input = torch.eye(input_size, dtype=torch.bfloat16) + output = linear(input) + + print(output.T) + print(torch.nn.functional.mse_loss(output.T, linear.weight)) + assert torch.allclose(output.T, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss + + +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", # equivalent to token + observer="minmax", + ), + {"default": torch.tensor(0.0)}, + {"default": torch.tensor(5.0)}, + torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]), + 0.06, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="token", # equivalent to tensor + observer="minmax", + ), + {"default": torch.tensor(0.0)}, + {"default": torch.tensor(5.0)}, + torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]), + 0.06, + ), + # channel is not supported, but is in principle equivalent to token/tensor + # ( + # QuantizationArgs( + # num_bits=4, + # type="int", + # symmetric=True, + # strategy="group", + # group_size=3, + # observer="minmax", + # ), + # { + # "default": torch.tensor([[0]]), + # 1: torch.tensor([[3]]), + # }, + # { + # "default": torch.tensor([[2]]), + # 1: torch.tensor([[5]]), + # }, + # torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]), + # 0.04, + # ), + # ( + # QuantizationArgs( + # num_bits=4, + # type="float", # tensor group requires FP4 + # symmetric=True, + # strategy="tensor_group", + # group_size=3, + # observer="minmax", + # ), + # { + # "default": torch.tensor([[0]]), + # 1: torch.tensor([[3]]), + # }, + # { + # "default": torch.tensor([[2]]), + # 1: torch.tensor([[5]]), + # }, + # torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]), + # 0.1, + # ), + # block is not supported, but is in principle similar to group + ], +) +def test_static_activation_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + input = tensor([[ 0, 1, 2, 3, 4, 5]]) + """ + # set up activation (and identity weight) + input_size = 6 + input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0) + linear = torch.nn.Linear(input_size, input_size, bias=False) + linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16) + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], input_activations=args) + 
initialize_module_for_quantization(linear, scheme) + assert getattr(linear, "quantization_scheme") is scheme + + # calibrate quantization parameters + initialize_observer(linear, "input") + linear.register_forward_pre_hook(calibrate_input_hook) + + # calibration forward pass + output = linear(input) + + # check calibration + observer = getattr(linear, "input_observer") + assert ( + observer.min_val.keys() + == observer.max_val.keys() + == exp_min_val.keys() + == exp_max_val.keys() + ) + for key in observer.min_val.keys(): + assert torch.equal(observer.min_val[key], exp_min_val[key]) + assert torch.equal(observer.max_val[key], exp_max_val[key]) + + # check forward pass + print(args.strategy) + print(output) + print(torch.nn.functional.mse_loss(output, input)) + assert torch.allclose(output, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output, input) <= exp_loss From 25e54a5e7bbe5c9f06ac47aef0777dc7d4503b35 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 09:22:42 -0400 Subject: [PATCH 2/9] remove reset Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/quantization/calibration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 2900f6bd3a..f94b242697 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -147,7 +147,6 @@ def update_weight_global_scale(module: Module): should_calculate_gparam=True, should_calculate_qparams=False, ) - module.weight_observer.reset() def update_weight_zp_scale(module: Module): From 06ece86c6a796b0145cf0f4b11d8489c07965f7e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 09:31:08 -0400 Subject: [PATCH 3/9] remove prints Signed-off-by: Kyle Sayers --- tests/llmcompressor/modifiers/calibration/test_observers.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index c7748cd1cc..57517ef691 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -260,8 +260,6 @@ def test_static_weight_quantization( input = torch.eye(input_size, dtype=torch.bfloat16) output = linear(input) - print(output.T) - print(torch.nn.functional.mse_loss(output.T, linear.weight)) assert torch.allclose(output.T, exp_quant.to(output.dtype)) assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss @@ -376,8 +374,5 @@ def test_static_activation_quantization( assert torch.equal(observer.max_val[key], exp_max_val[key]) # check forward pass - print(args.strategy) - print(output) - print(torch.nn.functional.mse_loss(output, input)) assert torch.allclose(output, exp_quant.to(output.dtype)) assert torch.nn.functional.mse_loss(output, input) <= exp_loss From 178d0aed6bc496cf8854e8fdb7b6963b503e7fc8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 09:48:58 -0400 Subject: [PATCH 4/9] remove commented code Signed-off-by: Kyle Sayers --- .../modifiers/calibration/test_observers.py | 42 +------------------ 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index 57517ef691..fb49ba5daa 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ 
b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -294,46 +294,8 @@ def test_static_weight_quantization( 0.06, ), # channel is not supported, but is in principle equivalent to token/tensor - # ( - # QuantizationArgs( - # num_bits=4, - # type="int", - # symmetric=True, - # strategy="group", - # group_size=3, - # observer="minmax", - # ), - # { - # "default": torch.tensor([[0]]), - # 1: torch.tensor([[3]]), - # }, - # { - # "default": torch.tensor([[2]]), - # 1: torch.tensor([[5]]), - # }, - # torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]), - # 0.04, - # ), - # ( - # QuantizationArgs( - # num_bits=4, - # type="float", # tensor group requires FP4 - # symmetric=True, - # strategy="tensor_group", - # group_size=3, - # observer="minmax", - # ), - # { - # "default": torch.tensor([[0]]), - # 1: torch.tensor([[3]]), - # }, - # { - # "default": torch.tensor([[2]]), - # 1: torch.tensor([[5]]), - # }, - # torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]), - # 0.1, - # ), + # group is not yet supported + # tensor_group is not yet supported # block is not supported, but is in principle similar to group ], ) From 7d81a79458ea09a5e5e401296493cc6cee2c0e49 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 17 Sep 2025 10:03:25 -0400 Subject: [PATCH 5/9] increase limit Signed-off-by: Kyle Sayers --- tests/llmcompressor/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llmcompressor/conftest.py b/tests/llmcompressor/conftest.py index f078fd0ae9..fb5629812a 100644 --- a/tests/llmcompressor/conftest.py +++ b/tests/llmcompressor/conftest.py @@ -61,7 +61,7 @@ def check_for_created_files(): f"Created files: {set(end_files_root) - set(start_files_root)}" ) - max_allowed_sized_temp_files_megabytes = 1 + max_allowed_sized_temp_files_megabytes = 1.5 end_files_temp = _get_files(directory=tempfile.gettempdir()) created_temp_files = set(end_files_temp) - set(start_files_temp) # pytest temp files are automatically deleted, exclude from size calculation From 05d780a1d1c88cd244433e876705aedb97921200 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 10:45:09 +0100 Subject: [PATCH 6/9] Update src/llmcompressor/observers/min_max.py Co-authored-by: Brian Dellabetta --- src/llmcompressor/observers/min_max.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index b806b6d7b6..5841202e06 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -145,7 +145,7 @@ def get_qparams_along_dim( # convert negative dims dim = [d if d >= 0 else observed.ndim + d for d in dim] - # reduce all dimensions except the the one passed as argument to this function + # reduce all dimensions except the ones passed as argument to this function reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) return self.calculate_qparams( observed, From dd91329779603824b7826a6956f361e43880da23 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 07:11:22 -0400 Subject: [PATCH 7/9] revert global disjointness Signed-off-by: Kyle Sayers --- .../modifiers/quantization/calibration.py | 1 + src/llmcompressor/observers/base.py | 4 ++++ src/llmcompressor/observers/min_max.py | 13 ++++++------- tests/llmcompressor/conftest.py | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py 
b/src/llmcompressor/modifiers/quantization/calibration.py index f94b242697..2900f6bd3a 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -147,6 +147,7 @@ def update_weight_global_scale(module: Module): should_calculate_gparam=True, should_calculate_qparams=False, ) + module.weight_observer.reset() def update_weight_zp_scale(module: Module): diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index aa9e1caab4..3160184ca3 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -51,8 +51,12 @@ def forward( :return: tuple of scale and zero point based on last observed value """ self.record_observed_tokens(observed) + if should_calculate_gparam: + # NOTE: this function updates running min/max values, which leads to + # running values updating twice return self.get_gparam(observed=observed) + return self.get_qparams( observed=observed, g_idx=g_idx, diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index 5841202e06..7f94f47888 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -3,7 +3,7 @@ import torch from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam -from compressed_tensors.utils import deprecated, patch_attr +from compressed_tensors.utils import deprecated from llmcompressor.observers.base import Observer @@ -87,12 +87,11 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor: :param observed: observed tensor to calculate quantization parameters for :return: updated global scale derived from the observed tensor """ - - # patch to avoid affecting running means - with patch_attr(self, "min_val", {}), patch_attr(self, "max_val", {}): - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed - ) + # NOTE: this function updates running min/max values, which leads to + # running values updating twice + updated_min_val, updated_max_val = self.calculate_updated_min_max( + observed=observed + ) return generate_gparam( updated_min_val=updated_min_val, updated_max_val=updated_max_val ) diff --git a/tests/llmcompressor/conftest.py b/tests/llmcompressor/conftest.py index fb5629812a..f078fd0ae9 100644 --- a/tests/llmcompressor/conftest.py +++ b/tests/llmcompressor/conftest.py @@ -61,7 +61,7 @@ def check_for_created_files(): f"Created files: {set(end_files_root) - set(start_files_root)}" ) - max_allowed_sized_temp_files_megabytes = 1.5 + max_allowed_sized_temp_files_megabytes = 1 end_files_temp = _get_files(directory=tempfile.gettempdir()) created_temp_files = set(end_files_temp) - set(start_files_temp) # pytest temp files are automatically deleted, exclude from size calculation From a6b384265e15bd423f3621705d4d8ec6df96fe64 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 07:54:00 -0400 Subject: [PATCH 8/9] remove safe permute Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 3160184ca3..d5a1c4c862 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -10,7 +10,6 @@ ) from compressed_tensors.quantization.utils import is_fp4 from compressed_tensors.registry.registry import RegistryMixin -from 
compressed_tensors.utils import safe_permute from loguru import logger from torch import FloatTensor, IntTensor, Tensor @@ -56,7 +55,7 @@ def forward( # NOTE: this function updates running min/max values, which leads to # running values updating twice return self.get_gparam(observed=observed) - + return self.get_qparams( observed=observed, g_idx=g_idx, @@ -172,8 +171,7 @@ def get_qparams( group_indices, group_sizes = torch.unique(g_idx, return_counts=True) group_sizes = group_sizes[torch.argsort(group_indices)] - perm = torch.argsort(g_idx) - observed = safe_permute(observed, perm, dim=1) + observed = observed.index_select(g_idx, -1) # TODO: experiment with vectorizing for loop for performance end = 0 From c0984474f8584409335513eeba0ab0f5fdf2d3dc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Sep 2025 08:04:13 -0400 Subject: [PATCH 9/9] fix typo Signed-off-by: Kyle Sayers --- src/llmcompressor/observers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index d5a1c4c862..ea325c7dc2 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -171,7 +171,7 @@ def get_qparams( group_indices, group_sizes = torch.unique(g_idx, return_counts=True) group_sizes = group_sizes[torch.argsort(group_indices)] - observed = observed.index_select(g_idx, -1) + observed = observed.index_select(-1, g_idx) # TODO: experiment with vectorizing for loop for performance end = 0
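
For illustration, a minimal standalone sketch (assumptions: plain torch only, example tensor values chosen here, not llmcompressor's Observer API) of the dim handling that PATCH 1/9 adds to MinMaxObserver.get_qparams_along_dim — an int or iterable of dims is accepted, negative indices are normalized, and every remaining dimension is reduced:

# Standalone illustration (hypothetical tensor and variable names, not llmcompressor code)
# of the dim-resolution logic from PATCH 1/9: normalize the requested dim(s), then
# reduce over all other dimensions to get per-dim extrema.
import torch

observed = torch.arange(24, dtype=torch.float32).reshape(4, 6)

dim = -1  # may also be an iterable, e.g. (0, 1)
dims = {dim} if isinstance(dim, int) else set(dim)
dims = {d if d >= 0 else observed.ndim + d for d in dims}
reduce_dims = tuple(i for i in range(observed.ndim) if i not in dims)

# per-kept-dim extrema, analogous to what calculate_qparams receives
min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
print(min_val.flatten().tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
print(max_val.flatten().tolist())  # [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]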