# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

7- import tempfile
87import unittest
9-
8+ from torchao .quantization .quantize_ .workflows .float8 .float8_semi_sparse_tensor import Float8SemiSparseTensor
9+ from torchao .quantization .quantize_ .workflows .float8 .float8_tensor import Float8Tensor
10+ from torchao .float8 .inference import Float8MMConfig
1011import torch
1112from torch .testing ._internal .common_utils import (
1213 TestCase ,
1314 instantiate_parametrized_tests ,
1415 parametrize ,
1516 run_tests ,
1617)
17-
18- from torchao .quantization import (
19- Float8WeightOnlyConfig ,
20- quantize_ ,
21- )
22- from torchao .quantization .utils import compute_error
2318from torchao .sparsity .sparse_api import apply_fake_sparsity
2419from torchao .testing .utils import skip_if_rocm
25- from torchao .utils import torch_version_at_least
20+ from torchao .utils import is_sm_at_least_90
2621
27- BF16_ACT_CONFIG = Float8WeightOnlyConfig (
28- group_size = 128 ,
29- packing_format = "cutlass_semi_sparse" ,
30- )
3122
32-
33- @unittest .skipIf (not torch_version_at_least ("2.8.0" ), "Need pytorch 2.8+" )
23+ @unittest .skipIf (not is_sm_at_least_90 (), "Need H100+ to run" )
3424@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
3525class TestFloat8SemiSparseTensor (TestCase ):
3626 def setUp (self ):
3727 self .GPU_DEVICES = ["cuda" ] if torch .cuda .is_available () else []
3828
3929 @skip_if_rocm ("ROCm enablement in progress" )
40- @parametrize ("config" , [BF16_ACT_CONFIG ])
4130 @parametrize (
4231 "sizes" ,
4332 [
4433 ((128 ,), 256 , 128 ),
4534 ((32 , 128 ), 512 , 128 ),
46- ((2 , 32 , 128 ), 256 , 12 ),
35+ ((2 , 32 , 128 ), 256 , 128 ),
4736 ],
4837 )
49- def test_linear (self , config , sizes ):
38+ def test_sparse_vs_dense_fp8 (self , sizes ):
5039 dtype = torch .bfloat16
5140 device = "cuda"
5241
@@ -55,52 +44,20 @@ def test_linear(self, config, sizes):
5544 linear = torch .nn .Linear (K , N , dtype = dtype , device = device )
5645
5746 apply_fake_sparsity (linear )
58- original = linear (input )
59- quantize_ (linear , config )
60- quantized = linear (input )
61- self .assertTrue (compute_error (original , quantized ) > 20 )
62-
63- compiled_linear = torch .compile (linear )
64- quantized_and_compiled = compiled_linear (input )
65- self .assertTrue (compute_error (original , quantized_and_compiled ) > 20 )
66-
67- @skip_if_rocm ("ROCm enablement in progress" )
68- @unittest .skip ("Fix later" )
69- @parametrize ("config" , [BF16_ACT_CONFIG ])
70- def test_to_device (self , config ):
71- for device in self .GPU_DEVICES :
72- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 )
73- quantize_ (linear , config )
74- linear .to (device )
75-
76- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 )
77- quantize_ (linear , config )
78- linear .to (device = device )
79-
80- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 )
81- quantize_ (linear , config )
82- linear .to (device )
83-
84- @skip_if_rocm ("ROCm enablement in progress" )
85- @parametrize ("config" , [BF16_ACT_CONFIG ])
86- def test_module_path (self , config ):
87- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 )
88- quantize_ (linear .cuda (), config )
89- self .assertEqual (
90- str (type (linear .weight )),
91- "<class 'torchao.quantization.Float8SemiSparseTensor'>" ,
47+
48+ mm_config = Float8MMConfig (use_fast_accum = True )
49+ input_fp8 = Float8Tensor .from_hp (input , float8_dtype = torch .float8_e4m3fn , mm_config = mm_config )
50+
51+ weight_fp8 = Float8Tensor .from_hp (linear .weight .data , float8_dtype = torch .float8_e4m3fn , mm_config = mm_config )
52+ dense_output = torch .nn .functional .linear (input_fp8 , weight_fp8 , linear .bias )
53+
54+ weight_sparse_fp8 = Float8SemiSparseTensor .from_hp (linear .weight .data , [1 , K ])
55+ sparse_output = torch .nn .functional .linear (input_fp8 , weight_sparse_fp8 , linear .bias )
56+
57+ torch .testing .assert_close (
58+ dense_output , sparse_output , atol = 3e-1 , rtol = 3e-1
9259 )
9360
94- with tempfile .NamedTemporaryFile () as f :
95- torch .save (linear .state_dict (), f )
96- f .seek (0 )
97- state_dict = torch .load (f )
98- self .assertEqual (
99- str (type (state_dict ["weight" ])),
100- "<class 'torchao.quantization.Float8SemiSparseTensor'>" ,
101- )
102-
103-
10461instantiate_parametrized_tests (TestFloat8SemiSparseTensor )
10562
10663
0 commit comments