Commit fc80e43

Implement packing and linear
Signed-off-by: Benji Beck <benjibeck@meta.com>
1 parent: 9b2cfe8

File tree

7 files changed: +231 -56 lines changed
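
For context, here is a minimal usage sketch of the workflow this commit wires up, distilled from the new test file below. It assumes a CUDA SM90+ GPU, since the packing kernel used is to_sparse_semi_structured_cutlass_sm9x_f8:

import torch
from torchao.quantization import Float8WeightOnlyConfig, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Config taken verbatim from the new test (BF16_ACT_CONFIG).
config = Float8WeightOnlyConfig(
    group_size=128,
    packing_format="cutlass_semi_sparse",
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_fake_sparsity(linear)  # prune the weight to a 2:4 pattern first
quantize_(linear, config)    # weight becomes a Float8SemiSparseTensor
out = linear(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))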
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Float8WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.testing.utils import skip_if_rocm
from torchao.utils import torch_version_at_least

BF16_ACT_CONFIG = Float8WeightOnlyConfig(
    group_size=128,
    packing_format="cutlass_semi_sparse",
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestFloat8SemiSparseTensor(TestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize("config", [BF16_ACT_CONFIG])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    def test_linear(self, config, sizes):
        dtype = torch.bfloat16
        device = "cuda"

        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)

        apply_fake_sparsity(linear)
        original = linear(input)
        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @skip_if_rocm("ROCm enablement in progress")
    @unittest.skip("Fix later")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            quantize_(linear, config)
            linear.to(device)

    @skip_if_rocm("ROCm enablement in progress")
    @parametrize("config", [BF16_ACT_CONFIG])
    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
        quantize_(linear.cuda(), config)
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8SemiSparseTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8SemiSparseTensor'>",
            )


instantiate_parametrized_tests(TestFloat8SemiSparseTensor)


if __name__ == "__main__":
    run_tests()
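
The "> 20" assertions read compute_error as a signal-to-quantization-noise ratio (SQNR) in dB, so 20 dB means the quantization error is an order of magnitude smaller than the reference output. A minimal sketch of the metric, assuming the usual torchao definition:

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))

ref = torch.randn(64)
approx = ref + 0.01 * torch.randn(64)
print(sqnr_db(ref, approx))  # roughly 40 dB for ~1% relative noise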

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,7 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
+    Float8SemiSparseTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -148,6 +149,7 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
+    "Float8SemiSparseTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 2 additions & 0 deletions
@@ -1336,6 +1336,7 @@ def _int8_weight_only_quantize_tensor(weight, config):
     if group_size is None:
         group_size = weight.shape[-1]
     block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
+    # todo: support fp8 semi-sparse
     new_weight = to_affine_quantized_intx(
         weight,
         mapping_type,
@@ -1584,6 +1585,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
     version: int = 2
+    # todo: add packing format

     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
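
The two TODOs mark where this integration will land: Float8WeightOnlyConfig does not yet expose the options the new test constructs it with (group_size=128, packing_format="cutlass_semi_sparse"). A hypothetical sketch of the follow-up, with field names and defaults assumed from the test rather than taken from shipped code:

# Assumed follow-up to "# todo: add packing format"; not the shipped definition.
@dataclass
class Float8WeightOnlyConfig(AOBaseConfig):
    weight_dtype: torch.dtype = e4m3_dtype
    set_inductor_config: bool = True
    version: int = 2
    group_size: Optional[int] = None  # assumed; the test passes group_size=128
    packing_format: str = "plain"     # assumed; the test passes "cutlass_semi_sparse"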

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 1 addition & 0 deletions
@@ -32,3 +32,4 @@ class PackingFormat(str, Enum):
     needed for the rest of the system to understand the specific format that's adopted.
     """
     OPAQUE = "opaque"
+    # todo: add semi-sparse
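
Given the string the test passes as packing_format, this TODO presumably resolves to a new enum member in the style of the existing ones; an assumed sketch, not shipped code:

class PackingFormat(str, Enum):
    ...
    OPAQUE = "opaque"
    # Assumed resolution of the "todo: add semi-sparse" marker above; the value
    # matches the packing_format string used by the new test.
    CUTLASS_SEMI_SPARSE = "cutlass_semi_sparse"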

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,6 @@
+from .float8.float8_semi_sparse_tensor import (
+    Float8SemiSparseTensor,
+)
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -38,6 +41,7 @@
     "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
+    "Float8SemiSparseTensor",
     "QuantizeTensorToFloat8Kwargs",
     "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",

torchao/quantization/quantize_/workflows/float8/cutlass_semi_sparse_fp8_tensor.py

Lines changed: 0 additions & 56 deletions
This file was deleted.
torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import torch

from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _quantize_affine_float8,
)
from torchao.utils import TorchAOBaseTensor

__all__ = ["Float8SemiSparseTensor"]
aten = torch.ops.aten


class Float8SemiSparseTensor(TorchAOBaseTensor):
    tensor_data_names = ["sparse", "scale", "meta"]

    def __new__(
        cls,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        kwargs = {}
        kwargs["device"] = sparse.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        shape = (sparse.shape[0], 2 * sparse.shape[-1])
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.sparse = sparse
        self.meta = meta
        self.scale = scale

    def _quantization_type(self):
        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        from torchao.sparsity.utils import mask_creator

        dense = w * mask_creator(w).bool()

        scale = _choose_scale_float8(
            dense,
            block_size=block_size,
            float8_dtype=torch.float8_e4m3fn,
        )

        w_fp8 = _quantize_affine_float8(
            dense,
            scale=scale,
            float8_dtype=torch.float8_e4m3fn,
        )

        sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8)

        return cls(
            sparse,
            meta,
            scale,
        )


implements = Float8SemiSparseTensor.implements
implements_torch_function = Float8SemiSparseTensor.implements_torch_function


@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8

    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )

    input = input_tensor.qdata
    input_scale = input_tensor.scale
    weight = weight_tensor.sparse
    weight_meta = weight_tensor.meta
    weight_scale = weight_tensor.scale
    out_dtype = input_tensor.dtype

    out = rowwise_scaled_linear_sparse_cutlass_f8f8(
        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
    )

    return out


Float8SemiSparseTensor.__module__ = "torchao.quantization"

# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Float8SemiSparseTensor])
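
Two details worth noting. First, the wrapper's logical shape is reconstructed from the packed payload: 2:4 semi-structured packing keeps two values out of every four along the last dimension, so sparse holds half the original columns and __new__ reports (sparse.shape[0], 2 * sparse.shape[-1]). Second, mask_creator is assumed to produce that 2:4 mask; a self-contained sketch of the idea (keep the two largest-magnitude entries in each group of four):

import torch

def naive_2_to_4_mask(w: torch.Tensor) -> torch.Tensor:
    # Group the last dim into blocks of four and keep the top-2 magnitudes per block.
    groups = w.abs().reshape(*w.shape[:-1], -1, 4)
    keep = torch.zeros_like(groups, dtype=torch.bool)
    keep.scatter_(-1, groups.topk(2, dim=-1).indices, True)
    return keep.reshape(w.shape)

w = torch.randn(8, 16)
mask = naive_2_to_4_mask(w)
assert mask.sum().item() == w.numel() // 2  # exactly half the weights survive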
