Commit 4199d06

Add test for fp8 semi-sparse vs. dense
Signed-off-by: Benji Beck <benjibeck@meta.com>
1 parent fc80e43 commit 4199d06

File tree

3 files changed: +75 -80 lines

test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py

Lines changed: 19 additions & 62 deletions
@@ -4,49 +4,38 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-import tempfile
 import unittest
-
+from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import Float8SemiSparseTensor
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+from torchao.float8.inference import Float8MMConfig
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
-
-from torchao.quantization import (
-    Float8WeightOnlyConfig,
-    quantize_,
-)
-from torchao.quantization.utils import compute_error
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import torch_version_at_least
+from torchao.utils import is_sm_at_least_90
 
-
-BF16_ACT_CONFIG = Float8WeightOnlyConfig(
-    group_size=128,
-    packing_format="cutlass_semi_sparse",
-)
 
-
-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
+@unittest.skipIf(not is_sm_at_least_90(), "Need H100+ to run")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 class TestFloat8SemiSparseTensor(TestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
     @skip_if_rocm("ROCm enablement in progress")
-    @parametrize("config", [BF16_ACT_CONFIG])
     @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
             ((32, 128), 512, 128),
-            ((2, 32, 128), 256, 12),
+            ((2, 32, 128), 256, 128),
         ],
     )
-    def test_linear(self, config, sizes):
+    def test_sparse_vs_dense_fp8(self, sizes):
         dtype = torch.bfloat16
         device = "cuda"
 
@@ -55,52 +44,20 @@ def test_linear(self, config, sizes):
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
 
         apply_fake_sparsity(linear)
-        original = linear(input)
-        quantize_(linear, config)
-        quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
-
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
-
-    @skip_if_rocm("ROCm enablement in progress")
-    @unittest.skip("Fix later")
-    @parametrize("config", [BF16_ACT_CONFIG])
-    def test_to_device(self, config):
-        for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, config)
-            linear.to(device)
-
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, config)
-            linear.to(device=device)
-
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, config)
-            linear.to(device)
-
-    @skip_if_rocm("ROCm enablement in progress")
-    @parametrize("config", [BF16_ACT_CONFIG])
-    def test_module_path(self, config):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear.cuda(), config)
-        self.assertEqual(
-            str(type(linear.weight)),
-            "<class 'torchao.quantization.Float8SemiSparseTensor'>",
+
+        mm_config = Float8MMConfig(use_fast_accum=True)
+        input_fp8 = Float8Tensor.from_hp(input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
+
+        weight_fp8 = Float8Tensor.from_hp(linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
+        dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)
+
+        weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
+        sparse_output = torch.nn.functional.linear(input_fp8, weight_sparse_fp8, linear.bias)
+
+        torch.testing.assert_close(
+            dense_output, sparse_output, atol=3e-1, rtol=3e-1
         )
 
-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(linear.state_dict(), f)
-            f.seek(0)
-            state_dict = torch.load(f)
-            self.assertEqual(
-                str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Float8SemiSparseTensor'>",
-            )
-
 instantiate_parametrized_tests(TestFloat8SemiSparseTensor)
 
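Note: the new test builds both weight representations directly instead of going through quantize_. A minimal standalone sketch of the same dense-vs-sparse comparison, assuming an SM 9.0+ GPU and the from_hp signatures shown in the diff above (the [1, K] argument is the per-row scale granularity used by the test):

import torch
from torchao.float8.inference import Float8MMConfig
from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import Float8SemiSparseTensor
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
from torchao.sparsity.sparse_api import apply_fake_sparsity

K, N = 128, 256
linear = torch.nn.Linear(K, N, dtype=torch.bfloat16, device="cuda")
apply_fake_sparsity(linear)  # prune the weight to a 2:4 pattern so semi-sparse packing is valid

x = torch.randn(32, K, dtype=torch.bfloat16, device="cuda")
mm_config = Float8MMConfig(use_fast_accum=True)
x_fp8 = Float8Tensor.from_hp(x, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)

# Dense fp8 reference vs. CUTLASS semi-sparse fp8
w_dense = Float8Tensor.from_hp(linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
w_sparse = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])

dense = torch.nn.functional.linear(x_fp8, w_dense, linear.bias)
sparse = torch.nn.functional.linear(x_fp8, w_sparse, linear.bias)
torch.testing.assert_close(dense, sparse, atol=3e-1, rtol=3e-1)
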
torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py

Lines changed: 52 additions & 15 deletions
@@ -19,7 +19,7 @@
 
 
 class Float8SemiSparseTensor(TorchAOBaseTensor):
-    tensor_data_names = ["sparse", "scale", "meta"]
+    tensor_data_names = ["sparse", "meta", "scale"]
 
     def __new__(
         cls,
@@ -83,29 +83,66 @@ def from_hp(
 implements_torch_function = Float8SemiSparseTensor.implements_torch_function
 
 
-@implements(aten.linear.default)
-@implements_torch_function(torch.nn.functional.linear)
+@implements(aten.t.default)
 def _(func, types, args, kwargs):
-    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
-
-    input_tensor, weight_tensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
+    from torch.utils._python_dispatch import return_and_correct_aliasing
+
+    self = args[0]
+    new = Float8SemiSparseTensor(
+        sparse=self.sparse,
+        meta=self.meta,
+        scale=self.scale,
     )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
 
-    input = input_tensor.qdata
-    input_scale = input_tensor.scale
+def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
+    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+    from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+
+    if isinstance(input_tensor, Float8Tensor):
+        input = input_tensor.qdata
+        input_scale = input_tensor.scale
+        out_dtype = input_tensor.dtype
+    else:
+        input = input_tensor.qdata
+        input_scale = input_tensor.scale
+        out_dtype = input_tensor.dtype
+
     weight = weight_tensor.sparse
     weight_meta = weight_tensor.meta
     weight_scale = weight_tensor.scale
-    out_dtype = input_tensor.dtype
-
-    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
+
+    # Reshape input_scale if needed: kernel expects scale to match input shape minus last dim
+    # For input [B, K], scale should be [B] not [B, 1]
+    if input_scale.dim() > input.dim() - 1:
+        input_scale = input_scale.squeeze(-1)
+
+    return rowwise_scaled_linear_sparse_cutlass_f8f8(
         input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
     )
 
-    return out
+
+@implements([aten.mm.default, aten.addmm.default])
+def _(func, types, args, kwargs):
+    if func == aten.addmm.default:
+        bias, input_tensor, weight_tensor = args
+    else: # aten.mm.default
+        input_tensor, weight_tensor = args
+        bias = None
+
+    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
 
 
 Float8SemiSparseTensor.__module__ = "torchao.quantization"
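
Aside: registering aten.t.default alongside aten.mm.default/aten.addmm.default matters because traced or compiled graphs often decompose F.linear(x, w, b) into addmm(b, x, w.t()) rather than calling the linear op. The transpose handler returns a new Float8SemiSparseTensor wrapping the same packed sparse/meta/scale data (the CUTLASS kernel consumes the weight in its packed form, so no data movement is needed); its job is only to keep the subclass alive so the following mm/addmm still dispatches here. A rough sketch of the two call paths, reusing x_fp8 and w_sparse from the test sketch above (illustrative; the exact dispatch order between the two subclasses depends on their handlers):

# Eager path: the torch_function override intercepts F.linear directly.
y_eager = torch.nn.functional.linear(x_fp8, w_sparse, linear.bias)

# Decomposed path: t() hits aten.t.default and returns another
# Float8SemiSparseTensor, then aten.addmm.default unpacks its args as
# (bias, input, weight) and routes into _linear_fp8_semi_sparse.
y_decomposed = torch.addmm(linear.bias, x_fp8, w_sparse.t())
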

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 4 additions & 3 deletions
@@ -256,9 +256,10 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
-    )
+
+    # If weight is not Float8Tensor, return NotImplemented to allow weight's dispatch to handle it
+    if not isinstance(weight_tensor, Float8Tensor):
+        return NotImplemented
 
     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantizing activation, if `act_quant_kwargs` is specified
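
The NotImplemented return follows Python's standard multiple-dispatch convention: when Float8Tensor's linear override sees a weight of a different subclass, declining lets PyTorch re-dispatch to the weight's own override. That is what allows a Float8Tensor activation paired with a Float8SemiSparseTensor weight to reach the semi-sparse kernel registered in the previous file. For example (illustrative, using the names from the sketches above):

# x_fp8 is a Float8Tensor activation, w_sparse a Float8SemiSparseTensor weight.
# Previously this tripped the assert in Float8Tensor's linear override; now
# Float8Tensor bows out and Float8SemiSparseTensor's override handles the call.
out = torch.nn.functional.linear(x_fp8, w_sparse, linear.bias)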
