 import torch
 from torch import nn
-
-from torchao.sparsity import (
-    apply_fake_sparsity,
-    sparsify_,
-    semi_sparse_weight,
-)
+from torch.testing._internal import common_utils
 from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
 from torchao.quantization.quant_api import (
+    int4_weight_only,
     int8_dynamic_activation_int8_weight,
     quantize_,
-    int4_weight_only,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
-from torch.testing._internal.common_utils import TestCase
+
+from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4


 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )

-class TestSemiStructuredSparse(TestCase):
+
+class TestSemiStructuredSparse(common_utils.TestCase):

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -37,6 +34,7 @@ def test_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )

         apply_fake_sparsity(model)
@@ -45,13 +43,17 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)

-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

-class TestQuantSemiSparse(TestCase):

-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+class TestQuantSemiSparse(common_utils.TestCase):
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quant_semi_sparse(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_quant_semi_sparse(self, compile):
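+        # don't force the CUTLASS backend; use cuSPARSELt 2:4 sparse kernels when available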
+        torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
         input = torch.rand((128, 128)).half().cuda()
         model = (
             nn.Sequential(
@@ -60,19 +62,27 @@ def test_quant_semi_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
         apply_fake_sparsity(model)
         model_copy = copy.deepcopy(model)
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         dense_result = model_copy(input)

-        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+        )
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)

-        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_sparse_marlin(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse_marlin(self, compile):
         input = torch.rand((256, 256)).half().cuda()
         model = (
             nn.Sequential(
@@ -81,6 +91,7 @@ def test_sparse_marlin(self):
             )
             .half()
             .cuda()
+            .eval()
         )

         apply_fake_sparsity(model)
@@ -92,9 +103,101 @@ def test_sparse_marlin(self):

         # Sparse + quantized
         quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)

-        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+
+class TestBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((1024, 1024)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(1024, 2048),
+                nn.Linear(2048, 1024),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
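+        # swap in 50% block-sparse weights laid out in 64x64 blocks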
+        M, N = model[0].weight.shape
+        model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        dense_result = model(input)
+
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            block_sparse_weight,
+        )
+
+        sparsify_(model, block_sparse_weight(blocksize=64))
+        # if compile:
+        #     model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+
+class TestQuantBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "pytorch 2.6+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((256, 128)).to(torch.bfloat16).cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .to(torch.bfloat16)
+            .cuda()
+            .eval()
+        )
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            blocksparse_int_addmm,
+        )
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
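+        # first layer: 50% block-sparse pattern (64x64 blocks) scaled by random values; second layer: plain block-sparse pattern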
+        M, N = model[0].weight.shape
+        model[0].weight.data = (
+            create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+            * torch.rand(M, N, dtype=torch.bfloat16).cuda()
+        )
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(model_copy, int8_dynamic_activation_int8_weight())
+        reference = model_copy(input)
+
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(
+                layout_type=BlockSparseLayoutType(blocksize=64)
+            ),
+        )
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
+
+
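+# generate the concrete test variants for the @common_utils.parametrize'd tests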
+common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
+common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
+common_utils.instantiate_parametrized_tests(TestBlockSparseWeight)
+common_utils.instantiate_parametrized_tests(TestQuantBlockSparseWeight)

 if __name__ == "__main__":
     unittest.main()