From 0125ffc2c8ed8b2c8def7cd41041006672063857 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Mon, 10 Nov 2025 13:48:43 -0800
Subject: [PATCH] Revert "Enable `PerRow(axis)` to support axes other than `-1` (#3303)"

This reverts commit cfbe695e4dad06d1177a5bea99190332cfc2efe8.
---
 .../workflows/float8/test_float8_tensor.py | 65 -------------------
 test/quantization/test_quant_primitives.py | 25 -------
 torchao/quantization/granularity.py        | 23 +++----
 .../workflows/float8/float8_tensor.py      |  4 +-
 torchao/quantization/utils.py              |  6 +-
 torchao/testing/utils.py                   | 12 ++--
 6 files changed, 15 insertions(+), 120 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 4bc106a60f..1b91875359 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -15,7 +15,6 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
-from torchao.core.config import config_from_dict, config_to_dict
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8Tensor,
@@ -635,44 +634,6 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)
 
-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
-    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
-    def test_bmm_weight_in_bkn_layout(self):
-        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
-        # and contigous with that shape. Since the `K` dimension is not last, we
-        # need to specify granularity with `PerRow(1)`.
-
-        # only support per row quantization
-        granularity = [PerRow(), PerRow(1)]
-        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-
-        class Model(torch.nn.Module):
-            def __init__(self, weight):
-                super().__init__()
-                self.weight = weight
-
-            def forward(self, x):
-                return torch.bmm(x, self.weight)
-
-        dtype = torch.bfloat16
-        device = "cuda"
-
-        B, M, K, N = 10, 32, 128, 256
-
-        input = torch.randn(B, M, K, dtype=dtype, device=device)
-        weight = torch.randn(B, K, N, dtype=dtype, device=device)
-        m = Model(weight).eval()
-        original = m(input)
-        quantize_(m, config, filter_fn=lambda x, fqn: True)
-
-        assert m.weight.scale.shape == (B, 1, N), (
-            f"unexpected scale shape {m.weight.scale.shape}"
-        )
-
-        quantized = m(input)
-        sqnr = compute_error(original, quantized)
-        self.assertTrue(sqnr > 20)
-
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",
@@ -1046,32 +1007,6 @@ def test_transpose(self):
         self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
         self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
 
-    def test_per_row_config_before_dim(self):
-        """
-        Test that loading a serialized config of `PerRow` before the `dim`
-        argument was introduced works properly
-        """
-
-        # create a config with PerRow granularity
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(),
-        )
-
-        # serialize it
-        config_ser = config_to_dict(config)
-
-        # manually modify the serialized config to match v1
-        # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
-        del config_ser["_data"]["granularity"][0]["_data"]["dim"]
-        del config_ser["_data"]["granularity"][1]["_data"]["dim"]
-        assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
-        assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0
-
-        # load the modified version, verify that granularity is as expected
-        config_deser = config_from_dict(config_ser)
-        assert config_deser.granularity[0].dim == -1
-        assert config_deser.granularity[1].dim == -1
-
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index cc6b7fff91..5f7895b4ea 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -10,7 +10,6 @@
 
 import torch
 
-from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -28,7 +27,6 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
-    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -846,29 +844,6 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
 
-    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
-        """
-        Test scaling a weight with shape (B, K, N) and row-major memory layout
-        across the K dimension.
-        """
-
-        B, K, N = 8, 16, 32
-        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
-
-        granularity = PerRow(1)
-        block_size = get_block_size(hp_tensor.shape, granularity)
-        scale = _choose_scale_float8(
-            hp_tensor,
-            float8_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            hp_value_lb=None,
-            hp_value_ub=None,
-        )
-        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
-
-        assert scale.shape == (B, 1, N)
-        assert data.shape == (B, K, N)
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index 97d9c07b6f..d83032d7be 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -39,14 +39,12 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
-    Examples:
-    * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
-    * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
-    * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
+    For example if the input tensor is shape [8, 16] and axis=0, then
+    the quantization parameters are calculated for each row of the tensor.
+    Giving a total of 8 quantization parameters.
 
     Attributes:
-        axis (int): The axis which is kept, reduction is performed across all
-            the other axes
+        axis (int): The axis along which reduction is performed.
     """
 
     axis: int
@@ -78,17 +76,12 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.
 
-    Examples:
-    * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
-    * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
-    * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
-    * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
-
-    Attributes:
-        dim (int): The dim which is reduced across, all other dims are kept
+    This is a special case of per-axis quantization and is unique to Float8 matmuls
+    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
+    is quantized with a block_size of (1, weight.shape[1]).
     """
 
-    dim: int = -1
+    pass
 
 
 @dataclass(frozen=True)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index abb9ddc1f9..a9c7af34b3 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -179,8 +179,6 @@ def from_hp(
             and _is_fbgemm_gpu_genai_available()
             and is_sm_at_least_90()
             and isinstance(granularity, PerRow)
-            # fbgemm path only supports quantizing along the last dim
-            and granularity.dim in (-1, len(hp_tensor.shape) - 1)
             and float8_dtype == torch.float8_e4m3fn
             and hp_value_lb is None
         ):
@@ -477,7 +475,7 @@ def _(func, types, args, kwargs):
 
         res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
             a_data,
-            b_data.transpose(-2, -1).contiguous(),
+            b_data.transpose(-2, -1),
             a_scale,
             b_scale.transpose(-2, -1),
             b_scale,
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 1a0375f3d2..db9a5149c3 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -723,12 +723,8 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, PerToken):
+    elif isinstance(granularity, (PerRow, PerToken)):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
-    elif isinstance(granularity, PerRow):
-        block_size = [1] * len(input_shape)
-        block_size[granularity.dim] = input_shape[granularity.dim]
-        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index 10315d45f5..a1dc40fdd3 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -444,9 +444,7 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight
-            + 1.0
-            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)
@@ -458,15 +456,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(output_dim, start_idx, shard_size)
-        orig_values = param_data.qdata[0]
+        orig_value = param_data.qdata[0][0]
         loaded_weight = dummy_l.weight
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
-        # making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0]
-        assert not torch.equal(orig_values, loaded_weight.qdata[0])
+        # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
+        assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
         param_data.copy_(loaded_weight)
         # making sure param.data is updated to loaded_weight
-        assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0])
+        assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0])
         if hasattr(param_data, "scale"):
             assert torch.equal(param_data.scale, loaded_weight.scale)
         if hasattr(param_data, "zero_point"):