Commit 22d7227

Add per tensor fp8 conv2d support
Summary:
Add fp8 conv2d support by reusing the conv3d kernels and setting the D dimension to 1:
1. unsqueeze both input and weight at dim 2 (the D dimension)
2. call the fp8 conv3d op from fbgemm: `torch.ops.fbgemm.f8f8bf16_conv`
3. assert that the D dimension of the result has size 1 and remove it with res.squeeze(2)

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
1 parent e8c4d09 commit 22d7227
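For context, here is a minimal sketch of the dispatch the summary describes: a 2D convolution treated as a 3D convolution with a unit-size D dimension. Only the op name `torch.ops.fbgemm.f8f8bf16_conv` comes from this commit; the argument layout below (quantized activation and weight, a combined per-tensor scale, and the usual conv parameters) is an assumption for illustration, not the real signature.

```python
import torch

# Hypothetical helper illustrating the conv2d-via-conv3d trick; the real
# implementation lives in torchao's Float8Tensor conv dispatch.
def fp8_conv2d_via_conv3d(act_fp8, w_fp8, scale, stride, padding, dilation):
    # 1. unsqueeze both input and weight at dim 2 to add a fake D dimension
    act_3d = act_fp8.unsqueeze(2)  # (N, C_in, H, W)       -> (N, C_in, 1, H, W)
    w_3d = w_fp8.unsqueeze(2)      # (C_out, C_in, K1, K2) -> (C_out, C_in, 1, K1, K2)

    # 2. call the fbgemm fp8 conv3d kernel; the D entries of the conv
    #    parameters are identity values (stride 1, padding 0, dilation 1)
    #    so the D dimension stays 1
    res = torch.ops.fbgemm.f8f8bf16_conv(  # assumed argument order
        act_3d,
        w_3d,
        scale,
        [1, *stride],
        [0, *padding],
        [1, *dilation],
    )

    # 3. the fake D dimension must still have size 1; squeeze it back out
    assert res.shape[2] == 1
    return res.squeeze(2)  # (N, C_out, 1, H_out, W_out) -> (N, C_out, H_out, W_out)
```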

3 files changed: +218 -46 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 119 additions & 34 deletions
@@ -87,6 +87,8 @@ def __init__(
         )
         if dim == 3:
             self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+        elif dim == 2:
+            self.conv = self.conv.to(memory_format=torch.channels_last)

     def forward(self, x):
         return self.conv(x)
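For reference on this memory-format branch: PyTorch's `torch.channels_last` applies to 4D (NCHW) tensors while `torch.channels_last_3d` applies to 5D (NCDHW) tensors, which is why the model and the tests below branch on `dim`. A small self-contained check:

```python
import torch

# channels_last keeps the NCHW/NCDHW shape but reorders strides so that
# the channel dimension becomes the fastest-moving one
x2d = torch.randn(4, 16, 32, 32).to(memory_format=torch.channels_last)
x3d = torch.randn(4, 16, 32, 32, 32).to(memory_format=torch.channels_last_3d)

assert x2d.stride()[1] == 1  # NHWC layout under an NCHW shape
assert x3d.stride()[1] == 1  # NDHWC layout under an NCDHW shape
```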
@@ -337,33 +339,43 @@ def _test_fp8_matmul_model(
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("inference_mode", [True, False])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 16, 64, 32, 32, 32),
+            (4, 16, 64, (32, 32, 32)),
+            (4, 16, 64, (32, 32)),
         ],
     )
     def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        torch.compiler.reset()
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
        kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            assert dim == 2
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)

-        # Create a linear layer with bfloat16 dtype
         model = ToyConvModel(
             dim,
             C_in,
@@ -382,9 +394,9 @@ def test_fp8_conv_variants(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)

         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -408,13 +420,16 @@ def test_fp8_conv_variants(
         "Requires fbgemm_gpu_genai to be installed",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 12, 64, 32, 32, 32),
-            (4, 16, 12, 32, 32, 32),
+            (4, 12, 64, (32, 32, 32)),
+            (4, 16, 12, (32, 32, 32)),
+            (4, 12, 64, (32, 32)),
+            (4, 16, 12, (32, 32)),
         ],
     )
     def test_fp8_conv_skip_quant(
@@ -427,14 +442,23 @@ def test_fp8_conv_skip_quant(
         """
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        # Create a linear layer with bfloat16 dtype
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
+
         model = ToyConvModel(
             dim,
             C_in,
@@ -453,9 +477,9 @@ def test_fp8_conv_skip_quant(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
         assert not isinstance(quantized_model.conv.weight, Float8Tensor)

         output_original = model(input_tensor)
@@ -832,7 +856,6 @@ def test_index_select(self):
         ],
     )
     def test_unsqueeze_operation(self, granularity, sizes):
-        """Test aten.unsqueeze.default operation on Float8Tensor"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
@@ -845,7 +868,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
         original_weight = linear.weight
         original_shape = original_weight.shape

-        # Test unsqueeze operation at dim=0 (only supported dimension)
+        # Test unsqueeze operation at dim=0
         unsqueezed_weight = original_weight.unsqueeze(0)

         # Verify the unsqueezed tensor has correct shape
@@ -887,22 +910,84 @@ def test_unsqueeze_operation(self, granularity, sizes):

         self.assertEqual(unsqueezed_dequant, expected_dequant)

-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_unsqueeze_error_cases(self, granularity):
-        """Test error cases for aten.unsqueeze.default operation"""
+    def test_unsqueeze_conv2d_weight(self):
+        granularity = PerTensor()
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
+        N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+        dim = len(spatial_dims)
+        kernel_size = 3

-        # Create a linear layer and quantize it
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        quantize_(linear, config)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+        input_tensor = input_tensor.to(memory_format=torch.channels_last)
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device=device,
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+        )
+
+        _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)

-        weight = linear.weight
+        quantize_(quantized_model, config, filter_fn=_is_conv)

-        # Test that unsqueezing on unsupported dimensions raises an error
-        with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
-            weight.unsqueeze(1)  # dim=1 should not be supported
+        original_weight = quantized_model.conv.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=2
+        unsqueezed_weight = original_weight.unsqueeze(2)
+
+        # Verify the unsqueezed tensor has correct shape
+        original_shape_list = list(original_shape)
+        expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+        scale_shape_list = list(original_weight.scale.shape)
+        expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+        # Verify qdata and scale shapes
+        expected_qdata_shape = expected_shape
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(2)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize("slice_dim", [0, 1, 2])

torchao/quantization/quant_api.py

Lines changed: 10 additions & 4 deletions
@@ -1816,13 +1816,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if weight.dim() == 5:
-        # weights for conv3d
+    # Note: right now we assume 4D/5D tensors are conv2d and conv3d weights,
+    # purely based on the dimension of the weight; currently there is no
+    # conflict with 2D linear weights or 3D MoE weights.
+    # If we need to support conv1d, which also has a 3D weight, we may have to
+    # pass around the module as well to distinguish conv1d from 3D MoE weights.
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or conv3d
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
-        ), "5D tensor only supports per tensor activation and weight quantization"
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"

-        # weight dim: (C_out, C_in, K1, K2, K3)
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
         # skip quantization when either C_out or C_in
         # is not a multiple of 16
         if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
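As a usage-level sketch of the skip rule above: with a per-tensor config, a conv whose C_out and C_in are both multiples of 16 gets a Float8Tensor weight, while one that is not keeps its original weight. The import paths are assumed to match what the test file uses; treat this as an illustration, not part of the commit.

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
_is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)

# C_out=64 and C_in=16 are both multiples of 16 -> weight is quantized
quantized = torch.nn.Conv2d(16, 64, 3, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(quantized, config, filter_fn=_is_conv)

# C_in=12 is not a multiple of 16 -> quantization is skipped, weight stays bf16
skipped = torch.nn.Conv2d(12, 64, 3, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(skipped, config, filter_fn=_is_conv)
```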
