From d76d3fff20967b0a33130fc682b8999c8f66aa44 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 5 Nov 2025 12:49:53 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 .../workflows/float8/test_float8_tensor.py | 18 ++++++++++--------
 .../workflows/float8/float8_tensor.py      | 13 +++++++------
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 4871b48849..26afb02aaa 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -444,25 +444,27 @@ def test_bmm(self):
         # only support per row quantization
         config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 
-        class M(torch.nn.Module):
+        class Model(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
                 self.weight = weight
 
             def forward(self, x):
-                return torch.bmm(x, self.weight)
+                return torch.bmm(x, self.weight.transpose(-2, -1))
 
         dtype = torch.bfloat16
         device = "cuda"
-        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
-        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
-        m = M(weight).eval()
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, N, K, dtype=dtype, device=device)
+        m = Model(weight).eval()
         original = m(input)
-        # we need to transpose the weight first for bmm
-        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 3581cb619c..a5e083d4cc 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -422,24 +422,25 @@ def _(func, types, args, kwargs):
     a_scale = input_tensor.scale
 
     b_data = weight_tensor.qdata
-    b_scale = weight_tensor.scale.squeeze(-1)
-    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+    b_scale = weight_tensor.scale
     assert (
-        all(x == 1 for x in weight_tensor.block_size[:-1])
-        and weight_tensor.block_size[-1] == weight_tensor.shape[-1]
+        weight_tensor.block_size[0] == 1
+        and weight_tensor.block_size[1] == weight_tensor.shape[1]
+        and weight_tensor.block_size[2] == 1
     ), "bmm only works for per row weight quantization"
     assert (
         all(x == 1 for x in input_tensor.block_size[:-1])
         and input_tensor.block_size[-1] == input_tensor.shape[-1]
     ), "bmm only works for per row activation quantization"
 
-    orig_out_features = b_data.shape[-2]
+    orig_out_features = b_data.shape[-1]
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
         a_data,
-        b_data,
+        b_data.transpose(-2, -1),
         a_scale,
+        b_scale.transpose(-2, -1),
         b_scale,
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)
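
Note on the change above: the test now stores the bmm weight as (B, N, K) and transposes it inside forward(), and the kernel path transposes b_data/b_scale itself instead of requiring a pre-transposed, contiguous weight (the old test's .transpose(1, 2).contiguous() step and the contiguity assert are dropped). Below is a minimal usage sketch that mirrors the updated test_bmm. It is a sketch only, assuming a CUDA device with the fbgemm float8 kernels available; the torchao.quantization import paths are assumptions, since the diff does not show the test file's imports.

    # Sketch mirroring the updated test_bmm above; assumes CUDA + fbgemm kernels.
    import torch

    # Assumed import paths; the actual test file's imports are not shown in the diff.
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )
    from torchao.quantization.utils import compute_error  # assumed helper location


    class Model(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = weight

        def forward(self, x):
            # Weight is stored as (B, N, K); transpose to (B, K, N) for bmm.
            return torch.bmm(x, self.weight.transpose(-2, -1))


    B, M, K, N = 10, 32, 128, 256
    x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(B, N, K, dtype=torch.bfloat16, device="cuda")

    m = Model(w).eval()
    ref = m(x)

    # Per-row dynamic float8 quantization; filter_fn returning True quantizes
    # every module, exactly as in the test above.
    quantize_(
        m,
        Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        filter_fn=lambda mod, fqn: True,
    )
    out = m(x)  # (B, M, N), computed via the float8 rowwise-batched kernel
    assert compute_error(ref, out) > 20  # SQNR threshold used by the test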