Commit cfbe695

Enable PerRow(axis) to support axes other than -1 (#3303)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 17867e6 commit cfbe695
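
For context, a minimal sketch of the capability this commit adds, mirroring the new bmm test below (import paths follow the test files in this diff):

    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
        quantize_,
    )

    # A torch.bmm weight stored as (B, K, N) is contiguous with the reduction
    # dim K at axis 1, not last, so the weight granularity is PerRow(1) rather
    # than the default PerRow(), which reduces across dim=-1.
    granularity = [PerRow(), PerRow(1)]  # [activation, weight]
    config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)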

File tree

6 files changed: +120 -15 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 65 additions & 0 deletions
@@ -15,6 +15,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
+from torchao.core.config import config_from_dict, config_to_dict
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8Tensor,
@@ -634,6 +635,44 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
+    def test_bmm_weight_in_bkn_layout(self):
+        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
+        # and contiguous with that shape. Since the `K` dimension is not last,
+        # we need to specify granularity with `PerRow(1)`.
+
+        # only support per row quantization
+        granularity = [PerRow(), PerRow(1)]
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+
+        class Model(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, K, N, dtype=dtype, device=device)
+        m = Model(weight).eval()
+        original = m(input)
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
+
+        assert m.weight.scale.shape == (B, 1, N), (
+            f"unexpected scale shape {m.weight.scale.shape}"
+        )
+
+        quantized = m(input)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",
@@ -1007,6 +1046,32 @@ def test_transpose(self):
         self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
         self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
 
+    def test_per_row_config_before_dim(self):
+        """
+        Test that loading a serialized `PerRow` config from before the
+        `dim` argument was introduced works properly
+        """
+
+        # create a config with PerRow granularity
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+        )
+
+        # serialize it
+        config_ser = config_to_dict(config)
+
+        # manually modify the serialized config to match v1
+        # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
+        del config_ser["_data"]["granularity"][0]["_data"]["dim"]
+        del config_ser["_data"]["granularity"][1]["_data"]["dim"]
+        assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
+        assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0
+
+        # load the modified version, verify that granularity is as expected
+        config_deser = config_from_dict(config_ser)
+        assert config_deser.granularity[0].dim == -1
+        assert config_deser.granularity[1].dim == -1
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

test/quantization/test_quant_primitives.py

Lines changed: 25 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 import torch
 
+from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -27,6 +28,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
 
+    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
+        """
+        Test scaling a weight with shape (B, K, N) and row-major memory
+        layout across the K dimension.
+        """
+
+        B, K, N = 8, 16, 32
+        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
+
+        granularity = PerRow(1)
+        block_size = get_block_size(hp_tensor.shape, granularity)
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        assert scale.shape == (B, 1, N)
+        assert data.shape == (B, K, N)
+
 
 if __name__ == "__main__":
     unittest.main()

torchao/quantization/granularity.py

Lines changed: 15 additions & 8 deletions
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
+    Examples:
+      * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
+      * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
 
     Attributes:
-        axis (int): The axis along which reduction is performed.
+        axis (int): The axis which is kept; reduction is performed across
+            all the other axes
     """
 
     axis: int
@@ -76,12 +78,17 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.
 
-    This is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
+    Examples:
+      * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
+      * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
+
+    Attributes:
+        dim (int): The dim which is reduced across; all other dims are kept
     """
 
-    pass
+    dim: int = -1
 
 
 @dataclass(frozen=True)
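
To make the two docstring conventions concrete, a small illustrative sketch (not library code): PerRow(dim) reduces across `dim` and keeps the rest, while PerAxis(axis) keeps `axis` and reduces across the rest.

    import torch

    x = torch.randn(8, 16, 32)  # [A, B, C]

    # PerRow(dim=1): dim 1 is reduced across, all other dims are kept
    amax = x.abs().amax(dim=1, keepdim=True)
    print(amax.shape)  # torch.Size([8, 1, 32]) -> scale_shape [A, 1, C]

    # PerAxis(axis=1): axis 1 is kept, all other axes are reduced across
    amax = x.abs().amax(dim=(0, 2), keepdim=True)
    print(amax.shape)  # torch.Size([1, 16, 1]) -> scale_shape [1, B, 1]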

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -179,6 +179,8 @@ def from_hp(
         and _is_fbgemm_gpu_genai_available()
         and is_sm_at_least_90()
         and isinstance(granularity, PerRow)
+        # fbgemm path only supports quantizing along the last dim
+        and granularity.dim in (-1, len(hp_tensor.shape) - 1)
         and float8_dtype == torch.float8_e4m3fn
         and hp_value_lb is None
     ):
@@ -475,7 +477,7 @@ def _(func, types, args, kwargs):
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
         a_data,
-        b_data.transpose(-2, -1),
+        b_data.transpose(-2, -1).contiguous(),
         a_scale,
         b_scale.transpose(-2, -1),
         b_scale,
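
The new guard in from_hp accepts both spellings of the last axis. A hedged sketch of that check (the helper name here is hypothetical, for illustration only):

    def _is_last_dim(dim: int, ndim: int) -> bool:
        # -1 and ndim - 1 name the same axis; the fbgemm rowwise kernels
        # only support quantizing along this last dim
        return dim % ndim == ndim - 1

    assert _is_last_dim(-1, 3) and _is_last_dim(2, 3)
    assert not _is_last_dim(1, 3)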

torchao/quantization/utils.py

Lines changed: 5 additions & 1 deletion
@@ -723,8 +723,12 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, (PerRow, PerToken)):
+    elif isinstance(granularity, PerToken):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
+    elif isinstance(granularity, PerRow):
+        block_size = [1] * len(input_shape)
+        block_size[granularity.dim] = input_shape[granularity.dim]
+        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"

torchao/testing/utils.py

Lines changed: 7 additions & 5 deletions
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight
+            + 1.0
+            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)
@@ -456,15 +458,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
             param = l.weight
             param_data = param.data
             param_data = param_data.narrow(output_dim, start_idx, shard_size)
-            orig_value = param_data.qdata[0][0]
+            orig_values = param_data.qdata[0]
             loaded_weight = dummy_l.weight
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
-            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-            assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
+            # making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0]
+            assert not torch.equal(orig_values, loaded_weight.qdata[0])
             param_data.copy_(loaded_weight)
             # making sure param.data is updated to loaded_weight
-            assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0])
+            assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0])
             if hasattr(param_data, "scale"):
                 assert torch.equal(param_data.scale, loaded_weight.scale)
             if hasattr(param_data, "zero_point"):
