Commit 9cf37be
Add per-tensor fp8 conv2d support

Summary: Add fp8 conv2d support by reusing the conv3d kernels with the D dimension set to 1:
1. Unsqueeze both input and weight at dim 2 (the D dimension).
2. Call the fp8 conv3d op from fbgemm: `torch.ops.fbgemm.f8f8bf16_conv`.
3. Assert that the result's D dimension has size 1 and call res.squeeze(2) to remove it.

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
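In code, steps 1-3 above amount to a thin wrapper around the conv3d kernel. A minimal sketch, assuming (not verified here) that `torch.ops.fbgemm.f8f8bf16_conv` accepts the fp8 input, fp8 weight, and a scale argument; the real op's signature (stride/padding, separate activation/weight scales) may differ:

import torch

def fp8_conv2d_via_conv3d(x_fp8, w_fp8, scale):
    # 1. unsqueeze both input and weight at dim 2 to add a dummy D dimension
    x3d = x_fp8.unsqueeze(2)  # (N, C_in, H, W) -> (N, C_in, 1, H, W)
    w3d = w_fp8.unsqueeze(2)  # (C_out, C_in, K1, K2) -> (C_out, C_in, 1, K1, K2)
    # 2. reuse the existing fp8 conv3d kernel
    res = torch.ops.fbgemm.f8f8bf16_conv(x3d, w3d, scale)
    # 3. the dummy D dimension must come back with size 1; squeeze it away
    assert res.shape[2] == 1
    return res.squeeze(2)  # (N, C_out, 1, H_out, W_out) -> (N, C_out, H_out, W_out)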
1 parent 86af458 commit 9cf37be

File tree

3 files changed: +218 -46 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 119 additions & 34 deletions
@@ -86,6 +86,8 @@ def __init__(
         )
         if dim == 3:
             self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+        elif dim == 2:
+            self.conv = self.conv.to(memory_format=torch.channels_last)

     def forward(self, x):
         return self.conv(x)
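Aside: converting to channels_last (2d) or channels_last_3d only rearranges strides; the logical shape is unchanged, which is why the shape assertions in the tests below are unaffected. A quick standalone check:

import torch

x = torch.randn(4, 16, 32, 32)
x_cl = x.to(memory_format=torch.channels_last)
assert x_cl.shape == x.shape  # logical shape is unchanged
assert x_cl.is_contiguous(memory_format=torch.channels_last)
assert x_cl.stride(1) == 1  # channels are now the fastest-varying dimension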
@@ -336,33 +338,43 @@ def _test_fp8_matmul_model(
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("inference_mode", [True, False])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 16, 64, 32, 32, 32),
+            (4, 16, 64, (32, 32, 32)),
+            (4, 16, 64, (32, 32)),
         ],
     )
     def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        torch.compiler.reset()
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            assert dim == 2
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)

-        # Create a linear layer with bfloat16 dtype
         model = ToyConvModel(
             dim,
             C_in,
@@ -381,9 +393,9 @@ def test_fp8_conv_variants(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)

         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
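Since the sizes tuples now nest the spatial dims, `dim` falls out of the tuple length and picks the conv class; a standalone illustration of the dispatch used above:

N, C_in, C_out, spatial_dims = (4, 16, 64, (32, 32))
dim = len(spatial_dims)  # 2, so convs[dim] selects torch.nn.Conv2d
input_shape = (N, C_in, *spatial_dims)  # (4, 16, 32, 32)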
@@ -407,13 +419,16 @@
         "Requires fbgemm_gpu_genai to be installed",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 12, 64, 32, 32, 32),
-            (4, 16, 12, 32, 32, 32),
+            (4, 12, 64, (32, 32, 32)),
+            (4, 16, 12, (32, 32, 32)),
+            (4, 12, 64, (32, 32)),
+            (4, 16, 12, (32, 32)),
         ],
     )
     def test_fp8_conv_skip_quant(
@@ -426,14 +441,23 @@ def test_fp8_conv_skip_quant(
         """
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        # Create a linear layer with bfloat16 dtype
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
+
         model = ToyConvModel(
             dim,
             C_in,
@@ -452,9 +476,9 @@ def test_fp8_conv_skip_quant(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
         assert not isinstance(quantized_model.conv.weight, Float8Tensor)

         output_original = model(input_tensor)
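For context on the skip path: with C_in = 12 or C_out = 12 (not multiples of 16), quantize_ is expected to leave the conv weight as a plain tensor, which is exactly what the assert above checks. A condensed sketch of the pattern, mirroring names from the diff:

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
quantize_(model, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Conv2d))
# C_in = 12 is not a multiple of 16, so quantization is skipped:
assert not isinstance(model.conv.weight, Float8Tensor)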
@@ -793,7 +817,6 @@ def test_index_select(self):
         ],
     )
     def test_unsqueeze_operation(self, granularity, sizes):
-        """Test aten.unsqueeze.default operation on Float8Tensor"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
@@ -806,7 +829,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
         original_weight = linear.weight
         original_shape = original_weight.shape

-        # Test unsqueeze operation at dim=0 (only supported dimension)
+        # Test unsqueeze operation at dim=0
         unsqueezed_weight = original_weight.unsqueeze(0)

         # Verify the unsqueezed tensor has correct shape
@@ -848,22 +871,84 @@ def test_unsqueeze_operation(self, granularity, sizes):

         self.assertEqual(unsqueezed_dequant, expected_dequant)

-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_unsqueeze_error_cases(self, granularity):
-        """Test error cases for aten.unsqueeze.default operation"""
+    def test_unsqueeze_conv2d_weight(self):
+        granularity = PerTensor()
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
+        N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+        dim = len(spatial_dims)
+        kernel_size = 3

-        # Create a linear layer and quantize it
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        quantize_(linear, config)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+        input_tensor = input_tensor.to(memory_format=torch.channels_last)
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device=device,
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+        )
+
+        _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)

-        weight = linear.weight
+        quantize_(quantized_model, config, filter_fn=_is_conv)

-        # Test that unsqueezing on unsupported dimensions raises an error
-        with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
-            weight.unsqueeze(1)  # dim=1 should not be supported
+        original_weight = quantized_model.conv.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=2
+        unsqueezed_weight = original_weight.unsqueeze(2)
+
+        # Verify the unsqueezed tensor has correct shape
+        original_shape_list = list(original_shape)
+        expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+        scale_shape_list = list(original_weight.scale.shape)
+        expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+        # Verify qdata and scale shapes
+        expected_qdata_shape = expected_shape
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(2)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize("slice_dim", [0, 1, 2])
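The shape bookkeeping in test_unsqueeze_conv2d_weight is plain arithmetic: inserting a size-1 axis at dim 2 shifts every trailing axis by one in qdata, scale, and block_size alike, and block_size is the elementwise ratio of tensor shape to scale shape. A standalone illustration with the test's sizes (the scale shape is an assumption for illustration):

original_shape = [64, 16, 3, 3]  # conv2d weight: (C_out, C_in, K1, K2)
expected_shape = original_shape[:2] + [1] + original_shape[2:]
assert expected_shape == [64, 16, 1, 3, 3]  # conv3d-style weight with D == 1
# assuming a per-tensor scale of shape [1, 1, 1, 1] -> [1, 1, 1, 1, 1],
# block_size == expected_shape, i.e. one quantization block covers the tensor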

torchao/quantization/quant_api.py

Lines changed: 10 additions & 4 deletions
@@ -1816,13 +1816,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if weight.dim() == 5:
-        # weights for conv3d
+    # Note: right now we assume the weight belongs to conv2d or conv3d purely based
+    # on the weight's dimensionality; currently there is no conflict with 2d linear
+    # weights or 3d MoE weights.
+    # If we need to support conv1d, which also has a 3d weight, we may have to
+    # pass around the module as well to distinguish conv1d from 3d MoE weights.
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or conv3d
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
-        ), "5D tensor only supports per tensor activation and weight quantization"
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"

-        # weight dim: (C_out, C_in, K1, K2, K3)
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
         # skip quantization when either C_out or C_in
         # is not a multiple of 16
         if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
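A minimal standalone sketch of the resulting gate, simplified from this hunk (`_should_quantize_conv_weight` is a hypothetical helper, not a function in torchao):

def _should_quantize_conv_weight(weight):
    # conv2d weight: (C_out, C_in, K1, K2); conv3d weight: (C_out, C_in, K1, K2, K3)
    if weight.dim() not in (4, 5):
        return False
    # skip quantization when either C_out or C_in is not a multiple of 16,
    # leaving the module's weight as a plain high-precision tensor
    return weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0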
