diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index be5f2361c3..18048a9d61 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -15,6 +15,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
+from torchao.core.config import config_from_dict, config_to_dict
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
@@ -577,6 +578,44 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
+    def test_bmm_weight_in_bkn_layout(self):
+        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
+        # and contiguous with that shape. Since the `K` dimension is not last,
+        # we need to specify granularity with `PerRow(1)`.
+
+        # only support per row quantization
+        granularity = [PerRow(), PerRow(1)]
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+
+        class Model(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, K, N, dtype=dtype, device=device)
+        m = Model(weight).eval()
+        original = m(input)
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
+
+        assert m.weight.scale.shape == (B, 1, N), (
+            f"unexpected scale shape {m.weight.scale.shape}"
+        )
+
+        quantized = m(input)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",
@@ -950,6 +989,32 @@ def test_transpose(self):
         self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
         self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
 
+    def test_per_row_config_before_dim(self):
+        """
+        Test that loading a serialized config of `PerRow` from before the
+        `dim` argument was introduced works properly
+        """
+
+        # create a config with PerRow granularity
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+        )
+
+        # serialize it
+        config_ser = config_to_dict(config)
+
+        # manually modify the serialized config to match v1
+        # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
+        del config_ser["_data"]["granularity"][0]["_data"]["dim"]
+        del config_ser["_data"]["granularity"][1]["_data"]["dim"]
+        assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
+        assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0
+
+        # load the modified version, verify that granularity is as expected
+        config_deser = config_from_dict(config_ser)
+        assert config_deser.granularity[0].dim == -1
+        assert config_deser.granularity[1].dim == -1
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
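The `(B, 1, N)` scale shape asserted in `test_bmm_weight_in_bkn_layout` comes from reducing only over the `K` dimension of the weight. A minimal sketch of that reduction with plain torch ops (illustrative only, not the torchao implementation):

    import torch

    B, K, N = 10, 128, 256
    weight = torch.randn(B, K, N)

    # PerRow(1) reduces across dim 1 (the K dim) and keeps all other dims,
    # so there is one scale per (batch, output-column) pair
    amax = weight.abs().amax(dim=1, keepdim=True)  # shape (B, 1, N)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    assert scale.shape == (B, 1, N)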
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 5f7895b4ea..cc6b7fff91 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@
 
 import torch
 
+from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -27,6 +28,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
 
+    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
+        """
+        Test scaling a weight with shape (B, K, N) and row-major memory
+        layout across the K dimension.
+        """
+
+        B, K, N = 8, 16, 32
+        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
+
+        granularity = PerRow(1)
+        block_size = get_block_size(hp_tensor.shape, granularity)
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        assert scale.shape == (B, 1, N)
+        assert data.shape == (B, K, N)
+
 
 if __name__ == "__main__":
     unittest.main()
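The test above leans on `get_block_size` translating `PerRow(dim)` into a block size that spans the reduced dim fully and is 1 everywhere else; the scale shape is then the input shape with the reduced dim collapsed to 1. A standalone mirror of the branch added to `get_block_size` in the `torchao/quantization/utils.py` hunk further below:

    # standalone mirror of the new PerRow branch of get_block_size
    def per_row_block_size(input_shape, dim):
        block_size = [1] * len(input_shape)
        block_size[dim] = input_shape[dim]  # negative dims index from the end
        return tuple(block_size)

    assert per_row_block_size((8, 16, 32), 1) == (1, 16, 1)   # scale shape (8, 1, 32)
    assert per_row_block_size((8, 16, 32), -1) == (1, 1, 32)  # scale shape (8, 16, 1)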
diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index d83032d7be..97d9c07b6f 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
+    Examples:
+    * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
+    * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
+    * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
 
     Attributes:
-        axis (int): The axis along which reduction is performed.
+        axis (int): The axis which is kept; reduction is performed across
+            all of the other axes
     """
 
     axis: int
@@ -76,12 +78,17 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.
 
-    This is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
+    Examples:
+    * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
+    * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
+    * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
+    * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
+
+    Attributes:
+        dim (int): The dim which is reduced across; all other dims are kept
     """
 
-    pass
+    dim: int = -1
 
 
 @dataclass(frozen=True)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index a9c7af34b3..abb9ddc1f9 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -179,6 +179,8 @@ def from_hp(
             and _is_fbgemm_gpu_genai_available()
             and is_sm_at_least_90()
             and isinstance(granularity, PerRow)
+            # fbgemm path only supports quantizing along the last dim
+            and granularity.dim in (-1, len(hp_tensor.shape) - 1)
             and float8_dtype == torch.float8_e4m3fn
             and hp_value_lb is None
         ):
@@ -475,7 +477,7 @@ def _(func, types, args, kwargs):
         res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
             a_data,
-            b_data.transpose(-2, -1),
+            b_data.transpose(-2, -1).contiguous(),
             a_scale,
             b_scale.transpose(-2, -1),
             b_scale,
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index db9a5149c3..1a0375f3d2 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -723,8 +723,12 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, (PerRow, PerToken)):
+    elif isinstance(granularity, PerToken):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
+    elif isinstance(granularity, PerRow):
+        block_size = [1] * len(input_shape)
+        block_size[granularity.dim] = input_shape[granularity.dim]
+        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"
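In the `float8_tensor.py` hunk above, the fbgemm fast path has to accept both spellings of "last dim", since `PerRow(-1)` and `PerRow(ndim - 1)` describe the same reduction. An equivalent normalization (illustrative helper, not a torchao API):

    def is_last_dim(dim: int, ndim: int) -> bool:
        # Python's modulo maps negative dims onto [0, ndim)
        return dim % ndim == ndim - 1

    assert is_last_dim(-1, 3) and is_last_dim(2, 3)
    assert not is_last_dim(1, 3)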
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index a1dc40fdd3..10315d45f5 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight
+            + 1.0
+            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)
@@ -456,15 +458,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(output_dim, start_idx, shard_size)
-        orig_value = param_data.qdata[0][0]
+        orig_values = param_data.qdata[0]
         loaded_weight = dummy_l.weight
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-        # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-        assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
+        # making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0]
+        assert not torch.equal(orig_values, loaded_weight.qdata[0])
         param_data.copy_(loaded_weight)
 
         # making sure param.data is updated to loaded_weight
-        assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0])
+        assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0])
         if hasattr(param_data, "scale"):
             assert torch.equal(param_data.scale, loaded_weight.scale)
         if hasattr(param_data, "zero_point"):
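Note on `test_per_row_config_before_dim` from the first file: the backward compatibility it checks falls out of `dim: int = -1` being a defaulted dataclass field, so a v1 payload that never stored `dim` reconstructs to the same object as `PerRow(dim=-1)`. A sketch with a hypothetical stand-in for torchao's `PerRow`:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class PerRowSketch:  # hypothetical stand-in, not torchao's class
        dim: int = -1

    # a serialized v1 config carries no "dim" entry; reconstructing with no
    # arguments yields the same object as explicitly passing dim=-1
    assert PerRowSketch() == PerRowSketch(dim=-1)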