65 changes: 65 additions & 0 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -15,6 +15,7 @@
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import run_tests

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
@@ -577,6 +578,44 @@ def forward(self, x):
sqnr = compute_error(original, quantized)
self.assertTrue(sqnr > 20)

@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_bmm_weight_in_bkn_layout(self):
# Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
# and contiguous with that shape. Since the `K` dimension is not last, we
# need to specify granularity with `PerRow(1)`.

# only per-row quantization is supported
granularity = [PerRow(), PerRow(1)]
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)

class Model(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"

B, M, K, N = 10, 32, 128, 256

input = torch.randn(B, M, K, dtype=dtype, device=device)
weight = torch.randn(B, K, N, dtype=dtype, device=device)
m = Model(weight).eval()
original = m(input)
quantize_(m, config, filter_fn=lambda x, fqn: True)

assert m.weight.scale.shape == (B, 1, N), (
f"unexpected scale shape {m.weight.scale.shape}"
)

quantized = m(input)
sqnr = compute_error(original, quantized)
self.assertTrue(sqnr > 20)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"sizes",
Expand Down Expand Up @@ -950,6 +989,32 @@ def test_transpose(self):
self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)

def test_per_row_config_before_dim(self):
"""
Test that a `PerRow` config serialized before the `dim` argument
was introduced still loads properly
"""

# create a config with PerRow granularity
config = Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(),
)

# serialize it
config_ser = config_to_dict(config)

# manually modify the serialized config to match v1
# reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
del config_ser["_data"]["granularity"][0]["_data"]["dim"]
del config_ser["_data"]["granularity"][1]["_data"]["dim"]
assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0

# load the modified version, verify that granularity is as expected
config_deser = config_from_dict(config_ser)
assert config_deser.granularity[0].dim == -1
assert config_deser.granularity[1].dim == -1


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

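For context, a condensed sketch of the workflow the bmm test above exercises. Imports are assumed to match the test file's (`PerRow` and `quantize_` from `torchao.quantization`); this is illustrative, not the canonical API surface.

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

class BMM(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = weight  # (B, K, N), contiguous in that layout

    def forward(self, x):
        return torch.bmm(x, self.weight)

m = BMM(torch.randn(10, 128, 256, dtype=torch.bfloat16, device="cuda")).eval()
# activation: per-row over the last dim; weight: per-row over K, which is dim 1
config = Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow(1)])
quantize_(m, config, filter_fn=lambda mod, fqn: True)
# reducing across K leaves one scale per (b, n) position
assert m.weight.scale.shape == (10, 1, 256)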
25 changes: 25 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@

import torch

from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
@@ -27,6 +28,7 @@
# TODO: remove test for utils?
from torchao.quantization.utils import (
_quantize_activation_per_token_absmax,
get_block_size,
get_group_qparams_symmetric,
groupwise_affine_dequantize_tensor_from_qparams,
groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)

def test_float8_rowwise_scaling_3d_weight_axis_1(self):
"""
Test rowwise scaling across the K dimension of a weight with shape
(B, K, N) and a row-major memory layout.
"""

B, K, N = 8, 16, 32
hp_tensor = torch.randn(B, K, N, dtype=torch.float)

granularity = PerRow(1)
block_size = get_block_size(hp_tensor.shape, granularity)
scale = _choose_scale_float8(
hp_tensor,
float8_dtype=torch.float8_e4m3fn,
block_size=block_size,
hp_value_lb=None,
hp_value_ub=None,
)
data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)

assert scale.shape == (B, 1, N)
assert data.shape == (B, K, N)


if __name__ == "__main__":
unittest.main()
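As a cross-check on the test above, a hand-rolled reference for the rowwise scale, assuming the usual float8 convention scale = amax / finfo(float8_dtype).max (so the quantized data is roughly hp / scale); this convention is an assumption, not pulled from the torchao internals shown here.

import torch

B, K, N = 8, 16, 32
hp_tensor = torch.randn(B, K, N, dtype=torch.float)
# PerRow(1) reduces across K, producing one scale per (b, n) position
amax = hp_tensor.abs().amax(dim=1, keepdim=True)         # shape (B, 1, N)
ref_scale = amax / torch.finfo(torch.float8_e4m3fn).max  # e4m3fn max is 448.0
assert ref_scale.shape == (B, 1, N)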
23 changes: 15 additions & 8 deletions torchao/quantization/granularity.py
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
This granularity type calculates different quantization parameters
along a specified axis of the tensor.
For example, if the input tensor has shape [8, 16] and axis=0, then
quantization parameters are calculated for each row of the tensor,
giving a total of 8 quantization parameters.
Examples:
* input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
* input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
* input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
Attributes:
axis (int): The axis along which reduction is performed.
axis (int): The axis that is kept; reduction is performed across all
other axes.
"""

axis: int
@@ -76,12 +78,17 @@ class PerRow(Granularity):
"""
Represents row-wise granularity in quantization.
This is a special case of per-axis quantization and is unique to Float8 matmuls,
where the input is quantized with a block_size of (1, ..., input.shape[-1]) and
the weight is quantized with a block_size of (1, weight.shape[1]).
Examples:
* input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
* input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
* input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
* input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
Attributes:
dim (int): The dim that is reduced across; all other dims are kept.
"""

pass
dim: int = -1


@dataclass(frozen=True)
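A standalone illustration of the dim -> scale_shape mapping documented above (the helper name is hypothetical, purely for demonstration):

def rowwise_scale_shape(shape: tuple, dim: int) -> tuple:
    # PerRow(dim): `dim` is reduced across, all other dims are kept
    out = list(shape)
    out[dim] = 1
    return tuple(out)

assert rowwise_scale_shape((4, 8), dim=-1) == (4, 1)       # default PerRow()
assert rowwise_scale_shape((4, 8), dim=0) == (1, 8)
assert rowwise_scale_shape((2, 4, 8), dim=1) == (2, 1, 8)  # (B, K, N) weight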
@@ -179,6 +179,8 @@ def from_hp(
and _is_fbgemm_gpu_genai_available()
and is_sm_at_least_90()
and isinstance(granularity, PerRow)
# fbgemm path only supports quantizing along the last dim
and granularity.dim in (-1, len(hp_tensor.shape) - 1)
and float8_dtype == torch.float8_e4m3fn
and hp_value_lb is None
):
@@ -475,7 +477,7 @@ def _(func, types, args, kwargs):

res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
a_data,
b_data.transpose(-2, -1),
b_data.transpose(-2, -1).contiguous(),
a_scale,
b_scale.transpose(-2, -1),
b_scale,
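The new guard accepts both spellings of the last dim before taking the fbgemm fast path. A tiny standalone version (hypothetical helper name, same logic as the condition added above):

def _is_last_dim(dim: int, ndim: int) -> bool:
    # PerRow.dim may be written as -1 or as the explicit last index;
    # the fbgemm kernels only support quantizing along the last dim
    return dim in (-1, ndim - 1)

assert _is_last_dim(-1, 3) and _is_last_dim(2, 3)
assert not _is_last_dim(1, 3)  # e.g. PerRow(1) on a (B, K, N) weight falls back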
6 changes: 5 additions & 1 deletion torchao/quantization/utils.py
@@ -723,8 +723,12 @@ def get_block_size(
f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
)
return block_size
elif isinstance(granularity, (PerRow, PerToken)):
elif isinstance(granularity, PerToken):
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
elif isinstance(granularity, PerRow):
block_size = [1] * len(input_shape)
block_size[granularity.dim] = input_shape[granularity.dim]
return tuple(block_size)
elif isinstance(granularity, PerGroup):
assert input_shape[-1] % granularity.group_size == 0, (
f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"
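A usage sketch for the new PerRow branch, with shapes taken from the tests in this PR (import paths assumed from the test files):

from torchao.quantization.granularity import PerRow
from torchao.quantization.utils import get_block_size

# each block spans the full reduced dim and has size 1 everywhere else
assert get_block_size((8, 16, 32), PerRow(1)) == (1, 16, 1)
# the default dim=-1 preserves the previous PerRow behavior
assert get_block_size((8, 16, 32), PerRow()) == (1, 1, 32)
assert get_block_size((4, 8), PerRow()) == (1, 8)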
12 changes: 7 additions & 5 deletions torchao/testing/utils.py
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
# making the weight different
dummy_l.weight = torch.nn.Parameter(
dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
dummy_l.weight
+ 1.0
+ 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
requires_grad=False,
)
quantize_(dummy_l, config)
@@ -456,15 +458,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
param = l.weight
param_data = param.data
param_data = param_data.narrow(output_dim, start_idx, shard_size)
orig_value = param_data.qdata[0][0]
orig_values = param_data.qdata[0]
Contributor Author:
The original check failed by chance: the two tensors had the same value at [0][0]. Checking the whole first row is more resistant to chance.

loaded_weight = dummy_l.weight
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

# making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
# making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0]
assert not torch.equal(orig_values, loaded_weight.qdata[0])
param_data.copy_(loaded_weight)
# making sure param.data is updated to loaded_weight
assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0])
assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0])
if hasattr(param_data, "scale"):
assert torch.equal(param_data.scale, loaded_weight.scale)
if hasattr(param_data, "zero_point"):