
Commit 86af458

Skip quantization when channels_out / channels_in are not multiples of 16 (#3309)

Summary: The underlying fbgemm conv3d kernel for float8 only supports cases where channels_out and channels_in are both multiples of 16, so for now we skip quantizing shapes that don't satisfy this requirement; if needed, support can be extended with padding in the future.

Test Plan: python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_skip_quant
1 parent 82eb780 commit 86af458
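
In user-facing terms, the change means quantize_ silently leaves a Conv3d module unquantized when its channel counts violate the kernel's constraint. A minimal sketch of that behavior, assuming a CUDA device with compute capability >= 10.0 and fbgemm_gpu_genai installed (the same preconditions as the test below); the filter_fn lambda mirrors the one used in the test:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8Tensor,
    PerTensor,
    quantize_,
)

# C_in = 12 is not a multiple of 16, so quantization is skipped for this module.
conv = torch.nn.Conv3d(12, 64, kernel_size=3, bias=False, device="cuda", dtype=torch.bfloat16)
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
quantize_(conv, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Conv3d))

# The weight stays a plain tensor instead of becoming a Float8Tensor.
assert not isinstance(conv.weight, Float8Tensor)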

File tree: 2 files changed, +78 -14 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 71 additions & 14 deletions
@@ -17,6 +17,7 @@
 
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8Tensor,
     Float8WeightOnlyConfig,
     Granularity,
     PerBlock,
@@ -25,7 +26,6 @@
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
-from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
@@ -329,14 +329,13 @@ def _test_fp8_matmul_model(
     @unittest.skipIf(
         not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
     )
+    @unittest.skipIf(
+        not _is_fbgemm_gpu_genai_available(),
+        "Requires fbgemm_gpu_genai to be installed",
+    )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor()])
     @common_utils.parametrize("inference_mode", [True, False])
-    @common_utils.parametrize(
-        "kernel_preference",
-        [KernelPreference.AUTO],
-    )
     # only test for 3D conv for now
     # Inputs are (N, C_in, C_out, D, H, W)
     @common_utils.parametrize(
@@ -349,19 +348,14 @@ def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
-        granularity,
         inference_mode: bool,
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
-        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
-            return unittest.skip(
-                "Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
-                "fbgemm kernel preference test"
-            )
-
-        dim = 3
+        granularity = PerTensor()
+        kernel_preference = KernelPreference.AUTO
         N, C_in, C_out, D, H, W = sizes
+        dim = 3
         kernel_size = 3
 
         # Note: this is channels-last memory format
@@ -404,6 +398,69 @@ def test_fp8_conv_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @unittest.skipIf(
+        not _is_fbgemm_gpu_genai_available(),
+        "Requires fbgemm_gpu_genai to be installed",
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    # only test for 3D conv for now
+    # Inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 12, 64, 32, 32, 32),
+            (4, 16, 12, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_skip_quant(
+        self,
+        dtype: torch.dtype,
+        sizes: Tuple,
+    ):
+        """Some shapes are not supported, so we won't quantize the module.
+        Specifically, we skip quantization when C_in or C_out is not a multiple of 16.
+        """
+        granularity = PerTensor()
+        kernel_preference = KernelPreference.AUTO
+        N, C_in, C_out, D, H, W = sizes
+        dim = 3
+        kernel_size = 3
+
+        # Note: this is channels-last memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        # Create a conv layer with the given dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        assert not isinstance(quantized_model.conv.weight, Float8Tensor)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+        self.assertEqual(output_original, output_quantized)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
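
The two parametrized shapes each violate one side of the constraint. A small standalone sketch (a hypothetical helper, not code from this commit) of the predicate the new quant_api.py check applies to 5D conv weights:

import torch

def supported_by_fbgemm_fp8_conv3d(weight: torch.Tensor) -> bool:
    """weight has shape (C_out, C_in, K1, K2, K3); the fbgemm float8 conv3d
    kernel requires both C_out and C_in to be multiples of 16."""
    C_out, C_in = weight.shape[0], weight.shape[1]
    return C_out % 16 == 0 and C_in % 16 == 0

# sizes=(4, 12, 64, 32, 32, 32) -> weight (64, 12, 3, 3, 3): C_in = 12 fails
assert not supported_by_fbgemm_fp8_conv3d(torch.empty(64, 12, 3, 3, 3))
# sizes=(4, 16, 12, 32, 32, 32) -> weight (12, 16, 3, 3, 3): C_out = 12 fails
assert not supported_by_fbgemm_fp8_conv3d(torch.empty(12, 16, 3, 3, 3))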

torchao/quantization/quant_api.py

Lines changed: 7 additions & 0 deletions
@@ -1821,6 +1821,13 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
         ), "5D tensor only supports per tensor activation and weight quantization"
+
+        # weight dim: (C_out, C_in, K1, K2, K3)
+        # skip quantization when either C_out or C_in
+        # is not a multiple of 16
+        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+            return weight
+
     elif not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
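
The summary mentions padding as a possible future extension. A hypothetical sketch (not part of this commit) of zero-padding a conv3d weight's channel dimensions up to the next multiples of 16; a real implementation would also need to pad the activation's C_in and slice the padded C_out back out of the output:

import torch
import torch.nn.functional as F

def pad_channels_to_multiple_of_16(weight: torch.Tensor) -> torch.Tensor:
    """Zero-pad a (C_out, C_in, K1, K2, K3) conv3d weight so both channel
    dims become multiples of 16."""
    pad_out = (-weight.shape[0]) % 16
    pad_in = (-weight.shape[1]) % 16
    # F.pad orders pads from the last dim backwards:
    # (K3, K3, K2, K2, K1, K1, C_in_before, C_in_after, C_out_before, C_out_after)
    return F.pad(weight, (0, 0, 0, 0, 0, 0, 0, pad_in, 0, pad_out))

w = torch.randn(64, 12, 3, 3, 3)
assert pad_channels_to_multiple_of_16(w).shape == (64, 16, 3, 3, 3)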
