From d76d3fff20967b0a33130fc682b8999c8f66aa44 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 5 Nov 2025 12:49:53 -0800
Subject: [PATCH 01/10] Update

[ghstack-poisoned]
---
 .../workflows/float8/test_float8_tensor.py    | 18 ++++++++++--------
 .../workflows/float8/float8_tensor.py         | 13 +++++++------
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 4871b48849..26afb02aaa 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -444,25 +444,27 @@ def test_bmm(self):
         # only support per row quantization
         config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 
-        class M(torch.nn.Module):
+        class Model(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
                 self.weight = weight
 
             def forward(self, x):
-                return torch.bmm(x, self.weight)
+                return torch.bmm(x, self.weight.transpose(-2, -1))
 
         dtype = torch.bfloat16
         device = "cuda"
-        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
-        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
-        m = M(weight).eval()
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, N, K, dtype=dtype, device=device)
+        m = Model(weight).eval()
         original = m(input)
-        # we need to transpose the weight first for bmm
-        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 3581cb619c..a5e083d4cc 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -422,24 +422,25 @@ def _(func, types, args, kwargs):
     a_scale = input_tensor.scale
 
     b_data = weight_tensor.qdata
-    b_scale = weight_tensor.scale.squeeze(-1)
-    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+    b_scale = weight_tensor.scale
 
     assert (
-        all(x == 1 for x in weight_tensor.block_size[:-1])
-        and weight_tensor.block_size[-1] == weight_tensor.shape[-1]
+        weight_tensor.block_size[0] == 1
+        and weight_tensor.block_size[1] == weight_tensor.shape[1]
+        and weight_tensor.block_size[2] == 1
     ), "bmm only works for per row weight quantization"
     assert (
         all(x == 1 for x in input_tensor.block_size[:-1])
        and input_tensor.block_size[-1] == input_tensor.shape[-1]
    ), "bmm only works for per row activation quantization"
 
-    orig_out_features = b_data.shape[-2]
+    orig_out_features = b_data.shape[-1]
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
         a_data,
-        b_data,
+        b_data.transpose(-2, -1),
         a_scale,
+        b_scale.transpose(-2, -1),
         b_scale,
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)

From 763a7091aa54e02614c0c22ffc57327ed569912b Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 6 Nov 2025 06:43:34 -0800
Subject: [PATCH 02/10] Update

[ghstack-poisoned]
---
 .../workflows/float8/test_float8_tensor.py    | 34 +++++++++++++++++++
 torchao/float8/inference.py                   | 15 ++++++--
 .../workflows/float8/float8_tensor.py         |  2 ++
 3 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 959c83c282..117b1de4e3 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -18,6 +18,7 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    PerAxis,
     PerBlock,
     PerRow,
     PerTensor,
@@ -466,6 +467,39 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
+    def test_bmm_weight_in_bkn_layout(self):
+        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
+        # and contiguous with that shape. Since the `K` dimension is not last, we
+        # need to specify granularity with `PerAxis(1)`.
+
+        # only support per row quantization
+        granularity = [PerRow(), PerAxis(1)]
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+
+        class Model(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, K, N, dtype=dtype, device=device)
+        m = Model(weight).eval()
+        original = m(input)
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
+        quantized = m(input)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index 212df9c5db..37966c22ef 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -15,6 +15,7 @@
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerAxis,
     PerBlock,
     PerRow,
     PerTensor,
@@ -247,13 +248,21 @@ def _normalize_granularity(
         granularity[1], PerTensor
     )
     is_per_row = isinstance(granularity[0], PerRow) and isinstance(
-        granularity[1], PerRow
+        granularity[1], (PerRow, PerAxis)
     )
     is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
 
     if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
         raise ValueError(f"Unsupported granularity types: {granularity}.")
-    if not isinstance(granularity[0], type(granularity[1])):
+
+    a_w_granularities_match = (
+        # direct match
+        isinstance(granularity[0], type(granularity[1]))
+        # PerAxis is a more general version of PerRow
+        or (isinstance(granularity[0], PerRow) and isinstance(granularity[1], PerAxis))
+    )
+
+    if not a_w_granularities_match:
         raise ValueError(
             f"Different granularities for activation and weight are not supported: {granularity}."
         )
@@ -280,7 +289,7 @@ def _check_hardware_support(
         granularities[1], PerTensor
     )
     is_per_row = isinstance(granularities[0], PerRow) and isinstance(
-        granularities[1], PerRow
+        granularities[1], (PerRow, PerAxis)
     )
     is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
 
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index a5e083d4cc..ad423b640c 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -423,6 +423,8 @@ def _(func, types, args, kwargs):
 
     b_data = weight_tensor.qdata
     b_scale = weight_tensor.scale
+    print('a', a_data.shape, a_scale.shape, input_tensor.block_size)
+    print('b', b_data.shape, b_scale.shape, weight_tensor.block_size)
 
     assert (
         weight_tensor.block_size[0] == 1

From 6143910385b2c226c82a5548bafc678deee5cbd1 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 6 Nov 2025 07:33:31 -0800
Subject: [PATCH 03/10] Update

[ghstack-poisoned]
---
 .../workflows/float8/test_float8_tensor.py    | 12 ++++++---
 test/quantization/test_quant_primitives.py    | 25 +++++++++++++++++++
 torchao/float8/inference.py                   | 15 +++--------
 torchao/quantization/granularity.py           | 18 ++++++++-----
 .../workflows/float8/float8_tensor.py         |  5 ++--
 torchao/quantization/utils.py                 |  6 ++++-
 6 files changed, 55 insertions(+), 26 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 117b1de4e3..26a433f93e 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -18,7 +18,6 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
-    PerAxis,
     PerBlock,
     PerRow,
     PerTensor,
@@ -472,10 +471,10 @@ def forward(self, x):
     def test_bmm_weight_in_bkn_layout(self):
         # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
         # and contiguous with that shape. Since the `K` dimension is not last, we
-        # need to specify granularity with `PerAxis(1)`.
-
+        # need to specify granularity with `PerRow(1)`.
+
         # only support per row quantization
-        granularity = [PerRow(), PerAxis(1)]
+        granularity = [PerRow(), PerRow(1)]
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
 
         class Model(torch.nn.Module):
@@ -496,6 +495,11 @@ def forward(self, x):
         m = Model(weight).eval()
         original = m(input)
         quantize_(m, config, filter_fn=lambda x, fqn: True)
+
+        assert m.weight.scale.shape == (B, 1, N), (
+            f"unexpected scale shape {m.weight.scale.shape}"
+        )
+
         quantized = m(input)
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 5f7895b4ea..cc6b7fff91 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@
 
 import torch
 
+from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -27,6 +28,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
 
+    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
+        """
+        Test scaling a weight with shape (B, K, N) and row-major memory layout
+        across the K dimension.
+        """
+
+        B, K, N = 8, 16, 32
+        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
+
+        granularity = PerRow(1)
+        block_size = get_block_size(hp_tensor.shape, granularity)
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        assert scale.shape == (B, 1, N)
+        assert data.shape == (B, K, N)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index 37966c22ef..212df9c5db 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -15,7 +15,6 @@
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
-    PerAxis,
     PerBlock,
     PerRow,
     PerTensor,
@@ -248,21 +247,13 @@ def _normalize_granularity(
         granularity[1], PerTensor
     )
     is_per_row = isinstance(granularity[0], PerRow) and isinstance(
-        granularity[1], (PerRow, PerAxis)
+        granularity[1], PerRow
     )
     is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
 
     if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
         raise ValueError(f"Unsupported granularity types: {granularity}.")
-
-    a_w_granularities_match = (
-        # direct match
-        isinstance(granularity[0], type(granularity[1]))
-        # PerAxis is a more general version of PerRow
-        or (isinstance(granularity[0], PerRow) and isinstance(granularity[1], PerAxis))
-    )
-
-    if not a_w_granularities_match:
+    if not isinstance(granularity[0], type(granularity[1])):
         raise ValueError(
             f"Different granularities for activation and weight are not supported: {granularity}."
         )
@@ -289,7 +280,7 @@ def _check_hardware_support(
         granularities[1], PerTensor
     )
     is_per_row = isinstance(granularities[0], PerRow) and isinstance(
-        granularities[1], (PerRow, PerAxis)
+        granularities[1], PerRow
     )
     is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
 
diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index d83032d7be..5f6439f0f1 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
+    Examples:
+      * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
+      * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
 
     Attributes:
-        axis (int): The axis along which reduction is performed.
+        axis (int): The axis which is kept; reduction is performed across all
+            the other axes
     """
 
     axis: int
@@ -76,12 +78,16 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.
 
-    This is a special case of per-axis quantization and is unique to Float8 matmuls
+    For 2D tensors, this is a special case of per-axis quantization and is unique to Float8 matmuls
     where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
     is quantized with a block_size of (1, weight.shape[1]).
+
+    TODO(before land): modify docblock for new axis argument
     """
 
-    pass
+    # TODO(before land): any BC concerns with loading old checkpoints
+    # serialized without this arg? investigate this
+    axis: int = -1
 
 
 @dataclass(frozen=True)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index ad423b640c..b59119c340 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -180,6 +180,7 @@ def from_hp(
         and _is_fbgemm_gpu_genai_available()
         and is_sm_at_least_90()
         and isinstance(granularity, PerRow)
+        and granularity.axis in (-1, len(hp_tensor.shape))
         and float8_dtype == torch.float8_e4m3fn
         and hp_value_lb is None
     ):
@@ -423,8 +424,6 @@ def _(func, types, args, kwargs):
 
     b_data = weight_tensor.qdata
     b_scale = weight_tensor.scale
-    print('a', a_data.shape, a_scale.shape, input_tensor.block_size)
-    print('b', b_data.shape, b_scale.shape, weight_tensor.block_size)
 
     assert (
         weight_tensor.block_size[0] == 1
@@ -440,7 +439,7 @@ def _(func, types, args, kwargs):
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
         a_data,
-        b_data.transpose(-2, -1),
+        b_data.transpose(-2, -1).contiguous(),
         a_scale,
         b_scale.transpose(-2, -1),
         b_scale,
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index db9a5149c3..2f4d82a891 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -723,8 +723,12 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, (PerRow, PerToken)):
+    elif isinstance(granularity, PerToken):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
+    elif isinstance(granularity, PerRow):
+        block_size = [1] * len(input_shape)
+        block_size[granularity.axis] = input_shape[granularity.axis]
+        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"

From 0073d60f84ac6395fc6a9fb27d5c64b6372eebbf Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 6 Nov 2025 13:24:21 -0800
Subject: [PATCH 04/10] Update

[ghstack-poisoned]
---
 torchao/quantization/granularity.py                           | 2 +-
 .../quantization/quantize_/workflows/float8/float8_tensor.py  | 3 ++-
 torchao/quantization/utils.py                                 | 2 +-
 torchao/testing/utils.py                                      | 4 +++-
 4 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index 5f6439f0f1..deb983ee7c 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -87,7 +87,7 @@ class PerRow(Granularity):
 
     # TODO(before land): any BC concerns with loading old checkpoints
     # serialized without this arg? investigate this
-    axis: int = -1
+    dim: int = -1
 
 
 @dataclass(frozen=True)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index b59119c340..2108fa8fa6 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -180,7 +180,8 @@ def from_hp(
         and _is_fbgemm_gpu_genai_available()
         and is_sm_at_least_90()
         and isinstance(granularity, PerRow)
-        and granularity.axis in (-1, len(hp_tensor.shape))
+        # fbgemm path only supports quantizing along the last dim
+        and granularity.dim in (-1, len(hp_tensor.shape) - 1)
         and float8_dtype == torch.float8_e4m3fn
         and hp_value_lb is None
     ):
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 2f4d82a891..1a0375f3d2 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -727,7 +727,7 @@ def get_block_size(
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
     elif isinstance(granularity, PerRow):
         block_size = [1] * len(input_shape)
-        block_size[granularity.axis] = input_shape[granularity.axis]
+        block_size[granularity.dim] = input_shape[granularity.dim]
         return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index a1dc40fdd3..f0ec9c114c 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight
+            + 1.0
+            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)

From a99196fe3ffb255e952307a28c3c3d5c34c40349 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 6 Nov 2025 13:38:09 -0800
Subject: [PATCH 05/10] Update

[ghstack-poisoned]
---
 torchao/quantization/granularity.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index deb983ee7c..f584a081cc 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -78,11 +78,14 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.
 
-    For 2D tensors, this is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
+    Examples:
+      * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
+      * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
 
-    TODO(before land): modify docblock for new axis argument
+    Attributes:
+        dim (int): The dim which is reduced across; all other dims are kept
     """
 
     # TODO(before land): any BC concerns with loading old checkpoints

From 84b06d1d074162514e23e1d9f4da848c3b70d9a3 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 6 Nov 2025 16:48:39 -0800
Subject: [PATCH 06/10] Update

[ghstack-poisoned]
---
 torchao/testing/utils.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index f0ec9c114c..ae062f7161 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -462,6 +462,13 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         loaded_weight = dummy_l.weight
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
+        # debugging CI failures
+        # TODO(before land): remove this
+        if not torch.equal(orig_value, loaded_weight.qdata[0][0]):
+            print("param_data.qdata", param_data.qdata)
+            print("orig_value", orig_value)
+            print("loaded_weight.qdata", loaded_weight.qdata)
+
         # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
         assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
         param_data.copy_(loaded_weight)

From d4ce3558daf57b4aa343ac5bf71b51f3cf725cb8 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 6 Nov 2025 17:17:25 -0800
Subject: [PATCH 07/10] Update

[ghstack-poisoned]
---
 .../workflows/float8/test_float8_tensor.py | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 26a433f93e..d5353d1ed3 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -15,6 +15,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
+from torchao.core.config import config_from_dict, config_to_dict
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
@@ -845,6 +846,32 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
 
         self.assertEqual(sliced_dequantized, sliced_original)
 
+    def test_per_row_config_before_dim(self):
+        """
+        Test that loading a `PerRow` config that was serialized before the
+        `dim` argument was introduced works properly
+        """
+
+        # create a config with PerRow granularity
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+        )
+
+        # serialize it
+        config_ser = config_to_dict(config)
+
+        # manually modify the serialized config to match v1
+        # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
+        del config_ser["_data"]["granularity"][0]["_data"]["dim"]
+        del config_ser["_data"]["granularity"][1]["_data"]["dim"]
+        assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
+        assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0
+
+        # load the modified version, verify that granularity is as expected
+        config_deser = config_from_dict(config_ser)
+        assert config_deser.granularity[0].dim == -1
+        assert config_deser.granularity[1].dim == -1
+
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

From b49280d4c0cb1b055c660112fb9fdfc437f17d62 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 7 Nov 2025 03:24:13 -0800
Subject: [PATCH 08/10] Update

[ghstack-poisoned]
---
 torchao/testing/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index ae062f7161..d79602ce9d 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -464,7 +464,7 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
 
         # debugging CI failures
         # TODO(before land): remove this
-        if not torch.equal(orig_value, loaded_weight.qdata[0][0]):
+        if torch.equal(orig_value, loaded_weight.qdata[0][0]):
             print("param_data.qdata", param_data.qdata)
             print("orig_value", orig_value)
             print("loaded_weight.qdata", loaded_weight.qdata)

From ec28bf5e0e5a2badc696f27bd44c9fb64a0a12d9 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 7 Nov 2025 04:25:07 -0800
Subject: [PATCH 09/10] Update

[ghstack-poisoned]
---
 torchao/testing/utils.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index d79602ce9d..10315d45f5 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -458,22 +458,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(output_dim, start_idx, shard_size)
-        orig_value = param_data.qdata[0][0]
+        orig_values = param_data.qdata[0]
         loaded_weight = dummy_l.weight
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
-        # debugging CI failures
-        # TODO(before land): remove this
-        if torch.equal(orig_value, loaded_weight.qdata[0][0]):
-            print("param_data.qdata", param_data.qdata)
-            print("orig_value", orig_value)
-            print("loaded_weight.qdata", loaded_weight.qdata)
-
-        # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-        assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
+        # making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0]
+        assert not torch.equal(orig_values, loaded_weight.qdata[0])
         param_data.copy_(loaded_weight)
         # making sure param.data is updated to loaded_weight
-        assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0])
+        assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0])
         if hasattr(param_data, "scale"):
             assert torch.equal(param_data.scale, loaded_weight.scale)
         if hasattr(param_data, "zero_point"):

From a6324b0d280e66fd8f51695c052749d219611087 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 7 Nov 2025 06:44:03 -0800
Subject: [PATCH 10/10] Update

[ghstack-poisoned]
---
 torchao/quantization/granularity.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index f584a081cc..97d9c07b6f 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -88,8 +88,6 @@ class PerRow(Granularity):
         dim (int): The dim which is reduced across; all other dims are kept
     """
 
-    # TODO(before land): any BC concerns with loading old checkpoints
-    # serialized without this arg? investigate this
     dim: int = -1
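
Taken together, this stack extends `PerRow` with a `dim` argument (default `-1`, preserving the previous per-row behavior and old serialized configs, per the test in patch 7) so that a 3d bmm weight stored in (B, K, N) layout can be quantized rowwise across K. Below is a minimal usage sketch condensed from `test_bmm_weight_in_bkn_layout` above; the model and variable names are illustrative, and like the test it assumes an sm90+ GPU with fbgemm_gpu_genai available.

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
from torchao.quantization.utils import get_block_size


class BmmModel(torch.nn.Module):
    # illustrative wrapper: holds a (B, K, N) weight whose reduction
    # dim K is not last
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        return torch.bmm(x, self.weight)


B, M, K, N = 10, 32, 128, 256

# activation: PerRow() keeps its old meaning (dim=-1);
# weight: PerRow(1) scales across dim 1 (K), keeping B and N
granularity = [PerRow(), PerRow(1)]

# the weight's block_size covers all of K and nothing else
assert get_block_size((B, K, N), PerRow(1)) == (1, K, 1)

m = BmmModel(torch.randn(B, K, N, dtype=torch.bfloat16, device="cuda")).eval()
quantize_(
    m,
    Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
    filter_fn=lambda mod, fqn: True,
)

# one scale per (b, n) column of K values, as asserted in the test above
assert m.weight.scale.shape == (B, 1, N)

x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
y = m(x)  # dispatches to fbgemm's f8f8bf16_rowwise_batched on this path
assert y.shape == (B, M, N)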