
Commit 0125ffc

Revert "Enable PerRow(axis) to support axes other than -1 (#3303)"
This reverts commit cfbe695.
1 parent e8c4d09 · commit 0125ffc

6 files changed (+15 -120 lines)

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 0 additions & 65 deletions
@@ -15,7 +15,6 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests

-from torchao.core.config import config_from_dict, config_to_dict
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8Tensor,
@@ -635,44 +634,6 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)

-    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
-    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
-    def test_bmm_weight_in_bkn_layout(self):
-        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
-        # and contiguous with that shape. Since the `K` dimension is not last, we
-        # need to specify granularity with `PerRow(1)`.
-
-        # only support per row quantization
-        granularity = [PerRow(), PerRow(1)]
-        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-
-        class Model(torch.nn.Module):
-            def __init__(self, weight):
-                super().__init__()
-                self.weight = weight
-
-            def forward(self, x):
-                return torch.bmm(x, self.weight)
-
-        dtype = torch.bfloat16
-        device = "cuda"
-
-        B, M, K, N = 10, 32, 128, 256
-
-        input = torch.randn(B, M, K, dtype=dtype, device=device)
-        weight = torch.randn(B, K, N, dtype=dtype, device=device)
-        m = Model(weight).eval()
-        original = m(input)
-        quantize_(m, config, filter_fn=lambda x, fqn: True)
-
-        assert m.weight.scale.shape == (B, 1, N), (
-            f"unexpected scale shape {m.weight.scale.shape}"
-        )
-
-        quantized = m(input)
-        sqnr = compute_error(original, quantized)
-        self.assertTrue(sqnr > 20)
-
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",
@@ -1046,32 +1007,6 @@ def test_transpose(self):
         self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
         self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)

-    def test_per_row_config_before_dim(self):
-        """
-        Test that loading a serialized config of `PerRow` before the `dim`
-        argument was introduced works properly
-        """
-
-        # create a config with PerRow granularity
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(),
-        )
-
-        # serialize it
-        config_ser = config_to_dict(config)
-
-        # manually modify the serialized config to match v1
-        # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
-        del config_ser["_data"]["granularity"][0]["_data"]["dim"]
-        del config_ser["_data"]["granularity"][1]["_data"]["dim"]
-        assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
-        assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0
-
-        # load the modified version, verify that granularity is as expected
-        config_deser = config_from_dict(config_ser)
-        assert config_deser.granularity[0].dim == -1
-        assert config_deser.granularity[1].dim == -1
-

 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
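Note: the deleted test depended on `PerRow` carrying a `dim` field. A minimal sketch of the post-revert API surface (the `try`/`except` probe below is illustrative, not part of the test suite):

from torchao.quantization.granularity import PerRow

# Post-revert, PerRow is a field-less frozen dataclass and always means
# "scale across the last dimension".
g = PerRow()

# The PerRow(1) spelling used by the deleted test no longer constructs.
try:
    PerRow(1)
except TypeError as e:
    print(f"PerRow no longer accepts a dim argument: {e}")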

test/quantization/test_quant_primitives.py

Lines changed: 0 additions & 25 deletions
@@ -10,7 +10,6 @@

 import torch

-from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -28,7 +27,6 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
-    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -846,29 +844,6 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)

-    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
-        """
-        Test scaling a weight with shape (B, K, N) and row-major memory layout
-        across the K dimension.
-        """
-
-        B, K, N = 8, 16, 32
-        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
-
-        granularity = PerRow(1)
-        block_size = get_block_size(hp_tensor.shape, granularity)
-        scale = _choose_scale_float8(
-            hp_tensor,
-            float8_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            hp_value_lb=None,
-            hp_value_ub=None,
-        )
-        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
-
-        assert scale.shape == (B, 1, N)
-        assert data.shape == (B, K, N)
-

 if __name__ == "__main__":
     unittest.main()
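Note: a hedged sketch of what replaces the deleted test's semantics, assuming only the `get_block_size` path restored in this commit: with the revert, a (B, K, N) weight is always scaled across N, so the (B, 1, N) scale shape asserted above is no longer reachable.

from torchao.quantization.granularity import PerRow
from torchao.quantization.utils import get_block_size

# PerRow now always spans the last dimension: block size (1, 1, N) for
# a (B, K, N) tensor, which yields a (B, K, 1) scale rather than the
# (B, 1, N) shape the deleted PerRow(1) test asserted.
B, K, N = 8, 16, 32
assert get_block_size((B, K, N), PerRow()) == (1, 1, N)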

torchao/quantization/granularity.py

Lines changed: 8 additions & 15 deletions
@@ -39,14 +39,12 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.

-    Examples:
-        * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
-        * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
-        * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
+    For example, if the input tensor has shape [8, 16] and axis=0, then
+    quantization parameters are calculated for each row of the tensor,
+    giving a total of 8 quantization parameters.

     Attributes:
-        axis (int): The axis which is kept, reduction is performed across all
-            the other axes
+        axis (int): The axis along which reduction is performed.
     """

     axis: int
@@ -78,17 +76,12 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.

-    Examples:
-        * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
-        * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
-        * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
-        * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
-
-    Attributes:
-        dim (int): The dim which is reduced across, all other dims are kept
+    This is a special case of per-axis quantization that is unique to Float8
+    matmuls, where the input is quantized with a block_size of
+    (1, ..., input.shape[-1]) and the weight is quantized with a block_size
+    of (1, weight.shape[1]).
     """

-    dim: int = -1
+    pass


 @dataclass(frozen=True)
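Note: a plain-`torch` sketch of the parameter counts the restored docstrings describe; the amax reduction is illustrative, not the library's exact float8 scaling recipe.

import torch

# PerAxis(axis=0) on an [8, 16] tensor: keep axis 0, reduce across the
# rest, giving one quantization parameter per row (8 in total).
x = torch.randn(8, 16)
assert x.abs().amax(dim=1, keepdim=True).shape == (8, 1)

# PerRow (post-revert) always reduces across the last dimension, so a
# 3d tensor keeps all leading dims in its scale shape.
x3 = torch.randn(4, 8, 16)
assert x3.abs().amax(dim=-1, keepdim=True).shape == (4, 8, 1)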

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 1 addition & 3 deletions
@@ -179,8 +179,6 @@ def from_hp(
         and _is_fbgemm_gpu_genai_available()
         and is_sm_at_least_90()
         and isinstance(granularity, PerRow)
-        # fbgemm path only supports quantizing along the last dim
-        and granularity.dim in (-1, len(hp_tensor.shape) - 1)
         and float8_dtype == torch.float8_e4m3fn
         and hp_value_lb is None
     ):
@@ -477,7 +475,7 @@ def _(func, types, args, kwargs):

     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
         a_data,
-        b_data.transpose(-2, -1).contiguous(),
+        b_data.transpose(-2, -1),
         a_scale,
         b_scale.transpose(-2, -1),
         b_scale,
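Note: the second hunk drops a `.contiguous()` call. A standalone sketch of the difference (tensor shapes are illustrative):

import torch

# transpose returns a view without moving data; the pre-revert
# .contiguous() call materialized a copy before the fbgemm kernel.
b_data = torch.randn(4, 8, 16)      # illustrative (B, K, N) data
view = b_data.transpose(-2, -1)     # (B, N, K) view, no copy
assert not view.is_contiguous()
assert view.contiguous().is_contiguous()  # what the removed call produced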

torchao/quantization/utils.py

Lines changed: 1 addition & 5 deletions
@@ -723,12 +723,8 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, PerToken):
+    elif isinstance(granularity, (PerRow, PerToken)):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
-    elif isinstance(granularity, PerRow):
-        block_size = [1] * len(input_shape)
-        block_size[granularity.dim] = input_shape[granularity.dim]
-        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"

torchao/testing/utils.py

Lines changed: 5 additions & 7 deletions
@@ -444,9 +444,7 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight
-            + 1.0
-            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)
@@ -458,15 +456,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(output_dim, start_idx, shard_size)
-        orig_values = param_data.qdata[0]
+        orig_value = param_data.qdata[0][0]
         loaded_weight = dummy_l.weight
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

-        # making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0]
-        assert not torch.equal(orig_values, loaded_weight.qdata[0])
+        # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
+        assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
         param_data.copy_(loaded_weight)
         # making sure param.data is updated to loaded_weight
-        assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0])
+        assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0])
         if hasattr(param_data, "scale"):
             assert torch.equal(param_data.scale, loaded_weight.scale)
         if hasattr(param_data, "zero_point"):
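Note: the narrow/copy_ pattern this test exercises works because `narrow` returns a view. A plain-tensor sketch, independent of the quantized `qdata` layout:

import torch

# narrow returns a view, so copying into the shard writes through to
# the original parameter, which is what the assertions above verify.
param = torch.zeros(4, 8)
shard = param.narrow(0, 0, 2)       # view of the first two rows
shard.copy_(torch.ones(2, 8))
assert torch.equal(param[0][0], torch.tensor(1.0))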
