
Conversation


@jerryzh168 jerryzh168 commented Nov 8, 2025

Summary:
Add fp8 conv2d support by reusing the conv3d kernels, with the D (depth) dimension set to 1:

  1. unsqueeze both input and weight at dim 2 (the D dimension)
  2. call the fp8 conv3d op from fbgemm: `torch.ops.fbgemm.f8f8bf16_conv`
  3. assert that the output D dimension is 1 and call `res.squeeze(2)` to remove it
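The three steps above can be sketched with plain (non-fp8) torch ops. This is only an illustration of the dimension bookkeeping, not the PR's actual fbgemm call:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)   # (N, C, H, W)
w = torch.randn(6, 4, 3, 3)   # (O, C, kH, kW)
padding, stride, dilation = [1, 1], [1, 1], [1, 1]

# 1. unsqueeze input and weight at dim 2 (the D dimension)
x3 = x.unsqueeze(2)           # (N, C, 1, H, W)
w3 = w.unsqueeze(2)           # (O, C, 1, kH, kW)

# 2. run the 3d conv, padding the D entry of each param with identity values
res = F.conv3d(x3, w3, padding=[0, *padding], stride=[1, *stride],
               dilation=[1, *dilation])

# 3. the output D dimension must be 1; squeeze it away
assert res.shape[2] == 1
res2d = res.squeeze(2)

# matches a direct 2d conv
ref = F.conv2d(x, w, padding=padding, stride=stride, dilation=dilation)
assert torch.allclose(res2d, ref, atol=1e-4)
```

In the PR itself the 3d call goes through the fp8 path instead of `F.conv3d`, but the unsqueeze/squeeze bookkeeping is the same.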

Test Plan:

python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight 
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants


pytorch-bot bot commented Nov 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3315

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (9 Unrelated Failures)

As of commit 22d7227 with merge base e8c4d09:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 8, 2025
@jerryzh168 jerryzh168 added the topic: new feature Use this tag if this PR adds a new feature label Nov 8, 2025
padding = [0, *padding]
stride = [1, *stride]
dilation = [1, *dilation]
res = _quantize_and_scaled_conv3d(
Contributor:

is this unsqueezing to turn 2d into 3d? Does fbgemm only have 3d conv kernels?

Contributor Author:

yeah, this turns the 2d conv into a 3d conv; it's because fbgemm only supports 3d conv right now


-if weight.dim() == 5:
+# weights for conv3d
+if weight.dim() in [4, 5]:
Contributor:

is there anything more robust we can check here?

Contributor Author:

we can't really distinguish here whether it's a linear or a conv weight, I think (although right now linear is 2d/3d and conv is 4d/5d; a conv1d weight could be 3d, which overlaps with linear)

but we could potentially separate the handling of conv and linear by passing around the module as well, if this is needed in the future

I can add a comment for now

"Please make sure both activation and weights are in the `channels_last` memory_format"
)
input_tensor = input_tensor.unsqueeze(2)
weight_tensor = weight_tensor.unsqueeze(2)
Contributor:

Can we reuse this code, e.g. have aten.conv2d call into aten.convolution?

Contributor Author:

yeah, I think that should be possible for both aten.conv2d and aten.conv3d; I can refactor that in a next PR?
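A side note on the hunk above: the unsqueeze at dim 2 is layout-safe with respect to the `channels_last` requirement in the error message. A minimal sketch with plain torch (an illustration, not the PR's code) showing that unsqueezing a `channels_last` NCHW tensor at dim 2 yields a tensor that is already contiguous in `channels_last_3d` (NDHWC with D=1), so no copy is needed before the 3d kernel:

```python
import torch

# activation in channels_last, as the error message requires
x = torch.randn(2, 4, 8, 8).to(memory_format=torch.channels_last)
assert x.is_contiguous(memory_format=torch.channels_last)

# inserting the D dimension preserves the "channels last" stride order
x3 = x.unsqueeze(2)  # (N, C, 1, H, W)
assert x3.is_contiguous(memory_format=torch.channels_last_3d)
```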

@jerryzh168 jerryzh168 force-pushed the add-conv2d branch 2 times, most recently from fe50ba6 to 9cf37be Compare November 10, 2025 18:09
Summary:
Add fp8 conv2d support by reusing the conv3d kernels, with the D dimension set to 1.

1. unsqueeze both input and weight at dim 2 (the D dimension)
2. call the fp8 conv3d op from fbgemm `torch.ops.fbgemm.f8f8bf16_conv`
3. assert that the output D dimension is 1 and call `res.squeeze(2)` to remove it

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
@jerryzh168 jerryzh168 merged commit df01de5 into pytorch:main Nov 10, 2025
9 of 18 checks passed

Labels

- CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
- topic: new feature: Use this tag if this PR adds a new feature
