@@ -17,6 +17,7 @@
 
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8Tensor,
     Float8WeightOnlyConfig,
     Granularity,
     PerBlock,
@@ -25,7 +26,6 @@
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
-from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
@@ -329,14 +329,13 @@ def _test_fp8_matmul_model(
     @unittest.skipIf(
         not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
     )
+    @unittest.skipIf(
+        not _is_fbgemm_gpu_genai_available(),
+        "Requires fbgemm_gpu_genai to be installed",
+    )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor()])
     @common_utils.parametrize("inference_mode", [True, False])
-    @common_utils.parametrize(
-        "kernel_preference",
-        [KernelPreference.AUTO],
-    )
     # only test for 3D conv for now
     # Inputs are (N, C_in, C_out, D, H, W)
     @common_utils.parametrize(
@@ -349,19 +348,13 @@ def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
-        granularity,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
-        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
-            return unittest.skip(
-                "Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
-                "fbgemm kernel preference test"
-            )
-
-        dim = 3
+        granularity = PerTensor()
+        kernel_preference = KernelPreference.AUTO
         N, C_in, C_out, D, H, W = sizes
+        dim = 3
         kernel_size = 3
 
         # Note: this is channel last memory format
@@ -404,6 +397,69 @@ def test_fp8_conv_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @unittest.skipIf(
+        not _is_fbgemm_gpu_genai_available(),
+        "Requires fbgemm_gpu_genai to be installed",
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    # only test for 3D conv for now
+    # Inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 12, 64, 32, 32, 32),
+            (4, 16, 12, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_skip_quant(
+        self,
+        dtype: torch.dtype,
+        sizes: Tuple,
+    ):
+        """Some shapes are not supported, so we won't quantize the module.
+        Specifically, we skip quantization when C_in or C_out is not a multiple of 16.
+        """
+        granularity = PerTensor()
+        kernel_preference = KernelPreference.AUTO
+        N, C_in, C_out, D, H, W = sizes
+        dim = 3
+        kernel_size = 3
+
+        # Note: this is channel last memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        # Create a conv layer with the given dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        assert not isinstance(quantized_model.conv.weight, Float8Tensor)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+        self.assertEqual(output_original, output_quantized)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
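For context, the shape gate that `test_fp8_conv_skip_quant` exercises can be sketched as below. This is a minimal sketch under the assumption stated in the test's docstring (the FP8 conv kernels need both channel dimensions to be multiples of 16); the helper name `_can_quantize_conv3d` is hypothetical, not torchao's actual implementation.

```python
import torch

def _can_quantize_conv3d(module: torch.nn.Module) -> bool:
    # Hypothetical predicate, not torchao's real code: decide whether a
    # Conv3d weight is eligible for float8 quantization.
    if not isinstance(module, torch.nn.Conv3d):
        return False
    # Conv3d weight layout is (C_out, C_in, kD, kH, kW); per the test's
    # docstring, quantization is skipped unless both channel dimensions
    # are multiples of 16.
    c_out, c_in = module.weight.shape[0], module.weight.shape[1]
    return c_out % 16 == 0 and c_in % 16 == 0
```

Under such a gate, both parametrized sizes fail one constraint each (C_in=12 in the first, C_out=12 in the second), so `quantize_` leaves `conv.weight` as a plain tensor and the quantized model's output equals the original's exactly, which is what the `assertEqual` checks.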