Commit 4b5b5ee

int8 dynamic quant + bsr support (#821)
This PR adds int8 dynamic quant + BSR (block sparse row) support. Changes:
* Use an i8i8 -> bf16 matmul to maintain accuracy
* Add a block sparse layout type to AffineQuantizedTensor, plus the corresponding check/impl
* Clean up the benchmark.py script and add a single-line `benchmark.sh` file for acceleration numbers
* Update eval.py and add a single-line `evaluate.sh` file for accuracy numbers
* Lots of lint formatting and README updates
* torch.compile now works and produces correct results
1 parent da0bbe3 commit 4b5b5ee
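
For orientation, here is a minimal usage sketch of the new path, adapted from the TestQuantBlockSparseWeight test added in this commit. The layer sizes, blocksize, and 50% sparsity are illustrative only; a CUDA device and a recent PyTorch (per the test's version guards) are assumed.

import torch
from torch import nn

from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight, quantize_
from torchao.sparsity.utils import create_block_sparse_tensor

# Toy bf16 model; in practice this would be a pruned model whose weights are block sparse.
model = (
    nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128))
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# Give each Linear a 50%-sparse weight made of 64x64 blocks, as the new test does.
for linear in model:
    M, N = linear.weight.shape
    linear.weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)

# int8 dynamic activation / int8 weight quantization, with weights stored in the new BSR layout.
quantize_(
    model,
    int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType(blocksize=64)),
)

model = torch.compile(model)  # the commit notes torch.compile now works on this path
out = model(torch.rand((256, 128), dtype=torch.bfloat16, device="cuda"))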

File tree

16 files changed: +1442 -922 lines changed


test/sparsity/test_sparse_api.py

Lines changed: 121 additions & 18 deletions
@@ -4,27 +4,24 @@
 
 import torch
 from torch import nn
-
-from torchao.sparsity import (
-    apply_fake_sparsity,
-    sparsify_,
-    semi_sparse_weight,
-)
+from torch.testing._internal import common_utils
 from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
 from torchao.quantization.quant_api import (
+    int4_weight_only,
     int8_dynamic_activation_int8_weight,
     quantize_,
-    int4_weight_only,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
-from torch.testing._internal.common_utils import TestCase
+
+from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
 
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
-class TestSemiStructuredSparse(TestCase):
+
+class TestSemiStructuredSparse(common_utils.TestCase):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -37,6 +34,7 @@ def test_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
 
         apply_fake_sparsity(model)
@@ -45,13 +43,17 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
-class TestQuantSemiSparse(TestCase):
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+class TestQuantSemiSparse(common_utils.TestCase):
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quant_semi_sparse(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_quant_semi_sparse(self, compile):
+        torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
         input = torch.rand((128, 128)).half().cuda()
         model = (
             nn.Sequential(
@@ -60,19 +62,27 @@ def test_quant_semi_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
         apply_fake_sparsity(model)
         model_copy = copy.deepcopy(model)
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         dense_result = model_copy(input)
 
-        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+        )
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_sparse_marlin(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse_marlin(self, compile):
         input = torch.rand((256, 256)).half().cuda()
         model = (
             nn.Sequential(
@@ -81,6 +91,7 @@ def test_sparse_marlin(self):
             )
             .half()
             .cuda()
+            .eval()
         )
 
         apply_fake_sparsity(model)
@@ -92,9 +103,101 @@ def test_sparse_marlin(self):
 
         # Sparse + quantized
         quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+
+class TestBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((1024, 1024)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(1024, 2048),
+                nn.Linear(2048, 1024),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        dense_result = model(input)
+
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            block_sparse_weight,
+        )
+
+        sparsify_(model, block_sparse_weight(blocksize=64))
+        # if compile:
+        #     model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+
+class TestQuantBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "pytorch 2.6+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((256, 128)).to(torch.bfloat16).cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .to(torch.bfloat16)
+            .cuda()
+            .eval()
+        )
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            blocksparse_int_addmm,
+        )
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = (
+            create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+            * torch.rand(M, N, dtype=torch.bfloat16).cuda()
+        )
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(model_copy, int8_dynamic_activation_int8_weight())
+        reference = model_copy(input)
+
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(
+                layout_type=BlockSparseLayoutType(blocksize=64)
+            ),
+        )
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
+
+
+common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
+common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
+common_utils.instantiate_parametrized_tests(TestBlockSparseWeight)
+common_utils.instantiate_parametrized_tests(TestQuantBlockSparseWeight)
 
 if __name__ == "__main__":
     unittest.main()
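
On the first bullet of the commit message (i8i8 -> bf16 to maintain accuracy), the sketch below shows the general idea in plain PyTorch rather than the torchao kernel: do the matmul in int8 with int32 accumulation, then rescale directly into bf16. The shapes, the per-row/per-column scales, and the use of torch._int_mm are illustrative assumptions, not the actual implementation.

import torch

# Hypothetical int8 operands and quantization scales (dynamic per-row activation
# scale, per-column weight scale); shapes are only for illustration.
x_int8 = torch.randint(-128, 128, (128, 64), dtype=torch.int8, device="cuda")
w_int8 = torch.randint(-128, 128, (64, 32), dtype=torch.int8, device="cuda")
x_scale = torch.rand(128, 1, dtype=torch.bfloat16, device="cuda")
w_scale = torch.rand(1, 32, dtype=torch.bfloat16, device="cuda")

# int8 x int8 matmul with int32 accumulation ...
acc_i32 = torch._int_mm(x_int8, w_int8)
# ... then dequantize straight into a bf16 output
y_bf16 = acc_i32.to(torch.bfloat16) * x_scale * w_scale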
