Commit 258387a
Add per tensor fp8 quantization support for conv3d (#3215)
Add per tensor fp8 quantization support for conv3d

Summary:

We added support for quantizing conv3d weights with the Float8DynamicActivationFloat8WeightConfig API:

```
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerTensor(),
)
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
quantize_(quantized_model, config, filter_fn=_is_conv3d)
```

Test Plan:

pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 1e473ed commit 258387a
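For context, a minimal end-to-end sketch of the new API. This assumes a CUDA device with compute capability >= 10.0 and fbgemm_gpu_genai installed; the module and size choices are illustrative, not from the commit:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

# toy conv3d in bf16; the fbgemm fp8 conv kernel expects channels_last_3d layout
model = torch.nn.Conv3d(16, 64, 3, bias=False, dtype=torch.bfloat16, device="cuda")
model = model.to(memory_format=torch.channels_last_3d)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
# quantize_ targets nn.Linear by default, so pass a filter that matches Conv3d
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
quantize_(model, config, filter_fn=_is_conv3d)

x = torch.randn(4, 16, 32, 32, 32, dtype=torch.bfloat16, device="cuda")
x = x.to(memory_format=torch.channels_last_3d)
out = model(x)  # (4, 64, 30, 30, 30), computed with fp8 activation and weight
```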

File tree

4 files changed: +230 −2 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 102 additions & 0 deletions

```diff
@@ -30,6 +30,7 @@
     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )
 
@@ -49,6 +50,28 @@ def forward(self, x):
         return x
 
 
+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
+        if dim == 3:
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -148,6 +171,85 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize("inference_mode", [True, False])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test 3D conv for now
+    # inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        inference_mode: bool,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
+            return unittest.skip(
+                "Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
+                "fbgemm kernel preference test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # convert the input to the channels_last_3d memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # create a conv model with the requested dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext()
+        with inference_mode_ctx:
+            output_original = model(input_tensor)
+            output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert error > 20, (
+            f"Quantization error is too high, got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
```
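The test accepts a quantized output when its SQNR against the unquantized baseline exceeds 20 dB. As a reference, here is a sketch of that metric matching torchao's compute_error up to implementation details (the canonical definition lives in torchao's quantization utilities):

```python
import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB: 20 * log10(||x|| / ||x - y||)
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - quantized)
    return 20 * torch.log10(signal / noise)

x = torch.randn(1024)
y = x + 0.01 * torch.randn(1024)
print(sqnr_db(x, y))  # roughly 40 dB, well above the test's 20 dB threshold
```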

torchao/quantization/quant_api.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -1797,7 +1797,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity
 
-    if not _fp8_mm_compat(weight):
+    if weight.dim() == 5:
+        # weights for conv3d
+        assert isinstance(activation_granularity, PerTensor) and isinstance(
+            weight_granularity, PerTensor
+        ), "5D tensor only supports per tensor activation and weight quantization"
+    elif not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight
```
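To illustrate the branch above: a 5D conv3d weight is accepted only when both activation and weight granularity are PerTensor; any other combination trips the assert. A hypothetical repro, assuming supported hardware (only the public torchao API names are real):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

conv = torch.nn.Conv3d(16, 64, 3, dtype=torch.bfloat16, device="cuda")
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
try:
    # PerRow scaling is defined for 2D matmul weights, so the 5D branch rejects it
    quantize_(
        conv,
        Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        filter_fn=_is_conv3d,
    )
except AssertionError as e:
    print(e)  # "5D tensor only supports per tensor activation and weight quantization"
```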

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 121 additions & 1 deletion

```diff
@@ -39,6 +39,7 @@
     _is_fbgemm_gpu_genai_available,
     fill_defaults,
     is_sm_at_least_90,
+    is_sm_at_least_100,
 )
 
 __all__ = [
@@ -261,7 +262,7 @@ def _(func, types, args, kwargs):
     )
 
     act_quant_kwargs = weight_tensor.act_quant_kwargs
-    # quantizing activation, if `act_quant_kwargs` is specified
+    # quantize activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
         input_tensor = _choose_quant_func_and_quantize_tensor(
             input_tensor, act_quant_kwargs
@@ -418,6 +419,125 @@ def _(func, types, args, kwargs):
     return res
 
 
+def _quantize_and_scaled_conv3d(
+    input_tensor,
+    weight_tensor,
+    bias,
+    stride,
+    padding,
+    dilation,
+):
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, (
+        "Only support 3D conv currently"
+    )
+    assert _is_fbgemm_gpu_genai_available(), (
+        "quantized fp8 conv3d requires fbgemm_gpu_genai to be available"
+    )
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantize activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    # initialized outside the isinstance check so the assert below fails with a
+    # clear message (not a NameError) when the activation is not quantized
+    kernel_choice = None
+    if isinstance(input_tensor, Float8Tensor):
+        if weight_tensor.kernel_preference == KernelPreference.AUTO:
+            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_100():
+                kernel_choice = "fbgemm"
+            else:
+                raise NotImplementedError(
+                    f"No available kernel choice for {weight_tensor.kernel_preference}"
+                )
+        elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
+            kernel_choice = "fbgemm"
+        else:
+            raise NotImplementedError(
+                f"No available kernel choice for {weight_tensor.kernel_preference}"
+            )
+
+    assert kernel_choice == "fbgemm", "Only fbgemm kernel choice is supported currently"
+    # move C_in to last dim
+    # after permute: (N, D, H, W, C_in)
+    act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1])
+
+    # move C_in to last dim
+    # after permute: (C_out, K1, K2, K3, C_in)
+    weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])
+
+    assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
+        "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+    )
+
+    act_scale = input_tensor.scale
+    weight_scale = weight_tensor.scale
+    output = torch.ops.fbgemm.f8f8bf16_conv(
+        act_qdata,
+        weight_qdata,
+        act_scale * weight_scale,
+        padding,
+        stride,
+        dilation,
+    )
+    # output shape after permute: N, C_out, D_out, H_out, W_out
+    output = output.permute([0, 4, 1, 2, 3])
+    return output
+
+
+@implements(aten.convolution.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    ) = args
+    assert not transposed, "transposed conv is not supported currently"
+    assert tuple(output_padding) == (0, 0, 0), (
+        f"Only (0, 0, 0) is supported for `output_padding`, got: {output_padding}"
+    )
+    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+    return _quantize_and_scaled_conv3d(
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+    )
+
+
+@implements(aten.conv3d.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
+    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+    return _quantize_and_scaled_conv3d(
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+    )
+
+
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     """Supports slicing for 1d, 2d, and 3d tensors
```

torchao/utils.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -32,6 +32,7 @@
     "is_MI300",
     "is_sm_at_least_89",
     "is_sm_at_least_90",
+    "is_sm_at_least_100",
     "is_package_at_least",
     "DummyModule",
     # Deprecated
```
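The diff only adds the export; the helper itself is defined elsewhere in torchao/utils.py and is not shown here. Presumably it mirrors the existing is_sm_at_least_90 pattern; a hedged sketch:

```python
import torch

# hedged sketch only: the real helper is not part of this diff; this follows
# the same compute-capability check pattern as is_sm_at_least_90
def is_sm_at_least_100() -> bool:
    return (
        torch.cuda.is_available()
        and torch.version.cuda is not None
        and torch.cuda.get_device_capability() >= (10, 0)
    )
```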
