From 22d7227de3c1f22ea7c43b7d59dec24fe487c335 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Sat, 8 Nov 2025 00:31:18 +0000
Subject: [PATCH] Add per tensor fp8 conv2d support

Summary:
Add fp8 conv2d support by reusing the conv3d kernels and setting the D
dimension to 1:

1. unsqueeze both input and weight at dim 2 (the D dimension)
2. call the fp8 conv3d op from fbgemm: `torch.ops.fbgemm.f8f8bf16_conv`
3. assert that the D dimension of the result has size 1 and call
   `res.squeeze(2)` to remove it

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
---
 .../workflows/float8/test_float8_tensor.py   | 153 ++++++++++++++----
 torchao/quantization/quant_api.py            |  14 +-
 .../workflows/float8/float8_tensor.py        |  97 ++++++++++-
 3 files changed, 218 insertions(+), 46 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 4bc106a60f..df11b71e66 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -87,6 +87,8 @@ def __init__(
         )
         if dim == 3:
             self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+        elif dim == 2:
+            self.conv = self.conv.to(memory_format=torch.channels_last)
 
     def forward(self, x):
         return self.conv(x)
@@ -337,12 +339,14 @@ def _test_fp8_matmul_model(
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("inference_mode", [True, False])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 16, 64, 32, 32, 32),
+            (4, 16, 64, (32, 32, 32)),
+            (4, 16, 64, (32, 32)),
         ],
     )
     def test_fp8_conv_variants(
@@ -350,20 +354,28 @@ def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        torch.compiler.reset()
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3
 
         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            assert dim == 2
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
 
-        # Create a linear layer with bfloat16 dtype
         model = ToyConvModel(
             dim,
             C_in,
@@ -382,9 +394,9 @@ def test_fp8_conv_variants(
             kernel_preference=kernel_preference,
         )
 
-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)
 
-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
 
         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -408,13 +420,16 @@ def 
test_fp8_conv_variants(
         "Requires fbgemm_gpu_genai to be installed",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 12, 64, 32, 32, 32),
-            (4, 16, 12, 32, 32, 32),
+            (4, 12, 64, (32, 32, 32)),
+            (4, 16, 12, (32, 32, 32)),
+            (4, 12, 64, (32, 32)),
+            (4, 16, 12, (32, 32)),
         ],
     )
     def test_fp8_conv_skip_quant(
@@ -427,14 +442,23 @@ def test_fp8_conv_skip_quant(
         """
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3
 
         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        # Create a linear layer with bfloat16 dtype
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
+
         model = ToyConvModel(
             dim,
             C_in,
@@ -453,9 +477,9 @@ def test_fp8_conv_skip_quant(
             kernel_preference=kernel_preference,
         )
 
-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)
 
-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
 
         assert not isinstance(quantized_model.conv.weight, Float8Tensor)
         output_original = model(input_tensor)
@@ -832,7 +856,6 @@ def test_index_select(self):
         ],
     )
     def test_unsqueeze_operation(self, granularity, sizes):
-        """Test aten.unsqueeze.default operation on Float8Tensor"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
@@ -845,7 +868,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
         original_weight = linear.weight
         original_shape = original_weight.shape
 
-        # Test unsqueeze operation at dim=0 (only supported dimension)
+        # Test unsqueeze operation at dim=0
         unsqueezed_weight = original_weight.unsqueeze(0)
 
         # Verify the unsqueezed tensor has correct shape
@@ -887,22 +910,84 @@ def test_unsqueeze_operation(self, granularity, sizes):
 
         self.assertEqual(unsqueezed_dequant, expected_dequant)
 
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_unsqueeze_error_cases(self, granularity):
-        """Test error cases for aten.unsqueeze.default operation"""
+    def test_unsqueeze_conv2d_weight(self):
+        granularity = PerTensor()
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
+        N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+        dim = len(spatial_dims)
+        kernel_size = 3
 
-        # Create a linear layer and quantize it
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        quantize_(linear, config)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+        input_tensor = input_tensor.to(memory_format=torch.channels_last)
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device=device,
+        ).eval()
+
+        
quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+        )
+
+        _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)
 
-        weight = linear.weight
+        quantize_(quantized_model, config, filter_fn=_is_conv)
 
-        # Test that unsqueezing on unsupported dimensions raises an error
-        with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
-            weight.unsqueeze(1)  # dim=1 should not be supported
+        original_weight = quantized_model.conv.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=2
+        unsqueezed_weight = original_weight.unsqueeze(2)
+
+        # Verify the unsqueezed tensor has correct shape
+        original_shape_list = list(original_shape)
+        expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+        scale_shape_list = list(original_weight.scale.shape)
+        expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+        # Verify qdata and scale shapes
+        expected_qdata_shape = expected_shape
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(2)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize("slice_dim", [0, 1, 2])
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e3a75bbb3e..09c2edcd9f 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1816,13 +1816,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         _check_hardware_support(granularity)
         activation_granularity, weight_granularity = granularity
 
-    if weight.dim() == 5:
-        # weights for conv3d
+    # Note: right now we assume a 4D/5D weight belongs to conv2d or conv3d
+    # purely based on its dimension; currently there is no conflict with 2d
+    # linear weights or 3d MoE weights.
+    # If we need to support conv1d, which also has a 3d weight, we may have to
+    # pass the module around as well to distinguish conv1d from 3d MoE weights.
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or conv3d
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
-        ), "5D tensor only supports per tensor activation and weight quantization"
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"
 
-        # weight dim: (C_out, C_in, K1, K2, K3)
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
         # skip quantization when either C_out or C_in
         # is 
not a multiple of 16
         if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index abb9ddc1f9..733d7a17a5 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -539,6 +539,7 @@ def _quantize_and_scaled_conv3d(
 
     # move C_in to last dim
     # after permute: (C_out, K1, K2, K3, C_in)
+
     weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])
 
     assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
@@ -574,10 +575,71 @@ def _(func, types, args, kwargs):
         groups,
     ) = args
     assert not transposed, "transposed conv is not supported currently"
-    assert tuple(output_padding) == (0, 0, 0), (
-        f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
-    )
+    dim = len(output_padding)
+    assert dim in [2, 3], "Only 2d or 3d convs are supported"
     assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+
+    if dim == 2:
+        assert input_tensor.is_contiguous(
+            memory_format=torch.channels_last
+        ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), (
+            "Please make sure both activation and weights are in the `channels_last` memory_format"
+        )
+        # (N, C, H, W) --> (N, C, 1, H, W)
+        input_tensor = input_tensor.unsqueeze(2)
+        weight_tensor = weight_tensor.unsqueeze(2)
+        assert tuple(output_padding) == (0, 0), (
+            f"Only (0, 0) is supported for `output_padding`, got: {output_padding}"
+        )
+        padding = [0, *padding]
+        stride = [1, *stride]
+        dilation = [1, *dilation]
+        res = _quantize_and_scaled_conv3d(
+            input_tensor,
+            weight_tensor,
+            bias,
+            stride,
+            padding,
+            dilation,
+        )
+        assert res.shape[2] == 1
+        res = res.squeeze(2)
+        return res
+    else:
+        assert input_tensor.is_contiguous(
+            memory_format=torch.channels_last_3d
+        ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), (
+            "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+        )
+        assert tuple(output_padding) == (0, 0, 0), (
+            f"Only (0, 0, 0) is supported for `output_padding`, got: {output_padding}"
+        )
+        return _quantize_and_scaled_conv3d(
+            input_tensor,
+            weight_tensor,
+            bias,
+            stride,
+            padding,
+            dilation,
+        )
+
+
+@implements(aten.conv3d.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
+    assert input_tensor.is_contiguous(
+        memory_format=torch.channels_last_3d
+    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), (
+        "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+    )
     return _quantize_and_scaled_conv3d(
         input_tensor,
         weight_tensor,
@@ -588,7 +650,7 @@ def _(func, types, args, kwargs):
     )
 
 
-@implements(aten.conv3d.default)
+@implements(aten.conv2d.default)
 def _(func, types, args, kwargs):
     (
         input_tensor,
@@ -598,9 +660,26 @@ def _(func, types, args, kwargs):
         padding,
         dilation,
         groups,
-    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
-    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
-    return _quantize_and_scaled_conv3d(
+    ) = fill_defaults(args, 7, [None, [1, 1], [0, 0], [1, 1], 1])
+    # (N, C, H, W) --> (N, C, 1, H, W)
+    # memory_format of both tensors should be torch.channels_last
+    # and it should be preserved 
with unsqueeze(2) (becoming torch.channels_last_3d) + assert input_tensor.is_contiguous( + memory_format=torch.channels_last + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), ( + "Please make sure both activation and weights are in the `channels_last` memory_format" + ) + input_tensor = input_tensor.unsqueeze(2) + weight_tensor = weight_tensor.unsqueeze(2) + + assert input_tensor.is_contiguous( + memory_format=torch.channels_last_3d + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d) + + padding = [0, *padding] + stride = [1, *stride] + dilation = [1, *dilation] + res = _quantize_and_scaled_conv3d( input_tensor, weight_tensor, bias, @@ -608,6 +687,9 @@ def _(func, types, args, kwargs): padding, dilation, ) + assert res.shape[2] == 1 + res = res.squeeze(2) + return res @implements(aten.slice.Tensor) @@ -839,7 +921,6 @@ def _(func, types, args, kwargs): @implements(aten.unsqueeze.default) def _(func, types, args, kwargs): self, dim = args - assert dim == 0, f"Only dim == 0 is supported, got: {dim}" qdata = self.qdata.unsqueeze(dim=dim) scale = self.scale.unsqueeze(dim=dim) block_size = []
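
For reference, a minimal standalone sketch of the conv2d-as-conv3d reduction this
patch relies on, in plain PyTorch without fp8 quantization. The sizes, kernel
size, and tolerances below are illustrative assumptions, not taken from the
patch; everything else (unsqueeze/squeeze, memory formats, F.conv2d/F.conv3d)
matches the approach above.

    import torch
    import torch.nn.functional as F

    N, C_in, C_out, H, W, K = 2, 16, 32, 32, 32, 3
    x = torch.randn(N, C_in, H, W).to(memory_format=torch.channels_last)
    w = torch.randn(C_out, C_in, K, K).to(memory_format=torch.channels_last)

    ref = F.conv2d(x, w)  # reference 2d conv

    # Insert a singleton D dimension at dim 2 for both activation and weight;
    # channels_last strides become channels_last_3d strides under unsqueeze(2).
    x3 = x.unsqueeze(2)  # (N, C_in, 1, H, W)
    w3 = w.unsqueeze(2)  # (C_out, C_in, 1, K, K)
    assert x3.is_contiguous(memory_format=torch.channels_last_3d)

    # Prepend identity values for the extra D dimension in stride/padding/dilation,
    # mirroring `stride = [1, *stride]`, `padding = [0, *padding]` above.
    out = F.conv3d(x3, w3, stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1])
    assert out.shape[2] == 1  # a 1-deep kernel with no D padding keeps D == 1
    out = out.squeeze(2)  # back to (N, C_out, H_out, W_out)

    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)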