 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    Granularity,
     PerBlock,
     PerRow,
     PerTensor,
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features, bias):
         super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
         self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
         self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)
 
@@ -50,6 +53,21 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
+    def check_weight_scaling(self, granularity: Granularity):
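+        # Expected scale shapes for an (out_features, in_features) weight:
+        # PerTensor -> (1, 1); PerRow -> (rows, 1), one scale per output row;
+        # (PerBlock([1, 128]), PerBlock([128, 128])) -> one scale per 128x128 weight tile.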
+        qs1 = self.linear1.weight.scale
+        qs2 = self.linear2.weight.scale
+        N, K = (self.out_features, self.in_features)
+        if granularity == PerTensor():
+            assert qs1.shape == (1, 1)
+            assert qs2.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs1.shape == (N, 1)
+            assert qs2.shape == (K, 1)
+        else:
+            assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
+            assert qs1.shape == (N // 128, K // 128)
+            assert qs2.shape == (K // 128, N // 128)
+
 
 class ToyConvModel(torch.nn.Module):
     def __init__(
@@ -73,6 +91,47 @@ def forward(self, x):
         return self.conv(x)
 
 
+class ToyLoRAModel(torch.nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        lora_rank: int,
+        device: torch.device,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear = torch.nn.Linear(
+            in_features,
+            out_features,
+            bias=False,
+            device=device,
+        )
+        self.lora_A = torch.nn.Parameter(
+            torch.randn(in_features, lora_rank, device=device),
+        )
+        self.lora_B = torch.nn.Parameter(
+            torch.randn(lora_rank, out_features, device=device),
+        )
+
+    def forward(self, x):
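+        # Base projection x @ W^T plus the low-rank update x @ A @ B. The
+        # explicit matmul on weight.t() (rather than self.linear(x)) routes the
+        # quantized weight through the transpose + matmul ops.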
+        matmul_out = torch.matmul(x, self.linear.weight.t())
+        lora_out = x @ self.lora_A @ self.lora_B
+        return matmul_out + lora_out
+
+    def check_weight_scaling(self, granularity: Granularity):
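+        # Only the base linear weight is quantized; lora_A / lora_B are plain
+        # parameters and stay in high precision.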
+        qs = self.linear.weight.scale
+        N, K = (self.out_features, self.in_features)
+        if granularity == PerTensor():
+            assert qs.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs.shape == (N, 1)
+        else:
+            assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
+            assert qs.shape == (N // 128, K // 128)
+
+
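+# A minimal usage sketch (names as imported above; the PerRow choice here is
+# illustrative, not part of any test):
+#     model = ToyLoRAModel(256, 128, lora_rank=8, device=torch.device("cuda"))
+#     quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+#     model.check_weight_scaling(PerRow())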
 # TODO: move tests in test_affine_quantized_float.py here after we have migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -112,10 +171,75 @@ def test_fp8_linear_variants(
         dtype: torch.dtype,
         mode: str,
         compile: bool,
-        granularity,
+        granularity: Granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
         bias: bool,
+    ):
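+        # sizes is ((M, ...), N, K); ToyLinearModel(K, N) maps K -> N -> K.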
+        _, N, K = sizes
+        self._test_fp8_matmul_model(
+            dtype,
+            mode,
+            compile,
+            granularity,
+            kernel_preference,
+            sizes,
+            bias,
+            ToyLinearModel(K, N, bias),
+        )
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+    )
+    # Each size is ((M, ...), N, K)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_fp8_matmul_lora_variants(
+        self,
+        dtype: torch.dtype,
+        mode: str,
+        compile: bool,
+        granularity: Granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        _, N, K = sizes
+        model = ToyLoRAModel(K, N, lora_rank=8, device=torch.device("cpu"))
+        self._test_fp8_matmul_model(
+            dtype,
+            mode,
+            compile,
+            granularity,
+            kernel_preference,
+            sizes,
+            bias=False,
+            model=model.to("cuda"),
+        )
+
+    def _test_fp8_matmul_model(
+        self,
+        dtype: torch.dtype,
+        mode: str,
+        compile: bool,
+        granularity: Granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+        bias: bool,
+        model: torch.nn.Module,
     ):
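+        # Shared driver for the linear and LoRA tests above: quantizes the
+        # model in place and checks the resulting weight scale layout.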
         if isinstance(granularity, PerTensor):
             if kernel_preference is KernelPreference.FBGEMM:
@@ -172,9 +296,7 @@ def test_fp8_linear_variants(
         with error_context:
             M, N, K = sizes
             input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
-            # Create a linear layer with bfloat16 dtype
-            model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")
+            model = model.eval().to(dtype).to("cuda")
 
             quantized_model = copy.deepcopy(model)
 
@@ -190,18 +312,7 @@ def test_fp8_linear_variants(
             quantize_(quantized_model, config)
 
             # ensure weight scaling is what we expect
-            qs1 = quantized_model.linear1.weight.scale
-            qs2 = quantized_model.linear2.weight.scale
-            if granularity == PerTensor():
-                assert qs1.shape == (1, 1)
-                assert qs2.shape == (1, 1)
-            elif granularity == PerRow():
-                assert qs1.shape == (N, 1)
-                assert qs2.shape == (K, 1)
-            else:
-                assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
-                assert qs1.shape == (N // 128, K // 128)
-                assert qs2.shape == (K // 128, N // 128)
+            quantized_model.check_weight_scaling(granularity)
 
             if compile:
                 quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -807,6 +918,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
 
         self.assertEqual(sliced_dequantized, sliced_original)
 
+    def test_to_dtype_layout(self):
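+        # .to(dtype=..., layout=..., device=...) dispatches to aten.to.dtype_layout;
+        # Float8Tensor should preserve dtype and layout while moving to CPU.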
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        y_fp8 = torch.ops.aten.to.dtype_layout(
+            x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
+        )
+        self.assertEqual(y_fp8.dtype, x_fp8.dtype)
+        self.assertEqual(y_fp8.layout, x_fp8.layout)
+        self.assertEqual(y_fp8.device, torch.device("cpu"))
+
+    def test_has_compatible_shallow_copy_type(self):
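+        # _has_compatible_shallow_copy_type gates Tensor.data-style shallow
+        # copies; it should hold only for two Float8Tensors of matching shape.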
+        x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
+        x1_fp8 = Float8Tensor.from_hp(x1)
+        x2_fp8 = Float8Tensor.from_hp(x2)
+        x3_fp8 = Float8Tensor.from_hp(x3)
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8))
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2))
+        self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8))
+        # Wrong shape
+        self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))
+
+    def test_transpose(self):
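+        # Transposing a Float8Tensor should transpose both qdata and scale,
+        # and flip the per-row block_size from (1, 512) to (512, 1).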
+        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        x_fp8_t = x_fp8.t()
+        torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
+        torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0)
+        self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
+        self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 