Add per tensor fp8 conv2d support #3315
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3315
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (9 Unrelated Failures) As of commit 22d7227 with merge base e8c4d09.
BROKEN TRUNK: the following jobs failed but were already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 007ea02 to 8bf4032
    padding = [0, *padding]
    stride = [1, *stride]
    dilation = [1, *dilation]
    res = _quantize_and_scaled_conv3d(
is this unsqueezing to turn 2d into 3d? Does fbgemm only have 3d conv kernels?
yeah this turns 2d conv to 3d and it's because fbgemm only supports 3d conv right now
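To make that concrete, here is a minimal sketch of the 2d → 3d lifting, assembled from the lines shown in the hunk above; the real helper in the PR, `_quantize_and_scaled_conv3d`, may take additional arguments (e.g. scales), so treat the signature as illustrative:

```python
def _conv2d_via_conv3d(input_tensor, weight_tensor, padding, stride, dilation):
    # Insert a dummy depth dimension D=1 so the fbgemm 3d kernel can be reused.
    input_3d = input_tensor.unsqueeze(2)    # (N, C, H, W)   -> (N, C, 1, H, W)
    weight_3d = weight_tensor.unsqueeze(2)  # (O, C, kH, kW) -> (O, C, 1, kH, kW)

    # Prepend neutral values for the depth dimension of the conv params.
    padding = [0, *padding]
    stride = [1, *stride]
    dilation = [1, *dilation]

    res = _quantize_and_scaled_conv3d(
        input_3d, weight_3d, padding, stride, dilation
    )

    # The depth dimension stays 1, so it can be squeezed back out.
    assert res.shape[2] == 1
    return res.squeeze(2)
```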
-    if weight.dim() == 5:
-        # weights for conv3d
+    if weight.dim() in [4, 5]:
is there anything more robust we can check here?
we can't really distinguish from here whether it's a linear or a conv weight, I think (although right now linear is 2d/3d and conv is 4d/5d; maybe conv1d could have a 3d weight, which would overlap with linear)
but we could potentially separate the handling of conv and linear by passing the module around as well, if this is needed in the future
I can add a comment for now
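For reference, the check being discussed boils down to a dim-based heuristic like the sketch below (illustrative only, not the exact code in the PR): linear weights are 2d/3d and conv weights are 4d/5d today, so a hypothetical 3d conv1d weight would be ambiguous.

```python
def _is_conv_weight(weight) -> bool:
    # conv2d weight: (O, C, kH, kW); conv3d weight: (O, C, kD, kH, kW)
    # Note: a conv1d weight (O, C, kW) would be 3d and overlap with a 3d
    # linear weight, which is why this heuristic is not fully robust.
    return weight.dim() in [4, 5]
```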
| "Please make sure both activation and weights are in the `channels_last` memory_format" | ||
| ) | ||
| input_tensor = input_tensor.unsqueeze(2) | ||
| weight_tensor = weight_tensor.unsqueeze(2) |
Can we reuse this code, e.g. have aten.conv2d call into aten.convolution?
yeah, I think that should be possible for both aten.conv2d and aten.conv3d; I can refactor that in the next PR?
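A possible shape for that follow-up refactor, just as a sketch (function names here are hypothetical, not the PR's actual ops): route the 2d case through the shared 3d path so the unsqueeze/squeeze bookkeeping lives in one place.

```python
def _scaled_convolution(input_tensor, weight_tensor, padding, stride, dilation):
    if weight_tensor.dim() == 4:
        # 2d case: lift to 3d, run the shared 3d path, drop the depth dim.
        out = _scaled_convolution(
            input_tensor.unsqueeze(2),
            weight_tensor.unsqueeze(2),
            [0, *padding],
            [1, *stride],
            [1, *dilation],
        )
        return out.squeeze(2)
    # 3d case: call the fp8 conv3d implementation directly.
    return _quantize_and_scaled_conv3d(
        input_tensor, weight_tensor, padding, stride, dilation
    )
```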
Force-pushed from fe50ba6 to 9cf37be
Summary:
Add fp8 conv2d support, using the same conv3d kernels, by setting the D dimension to 1.
1. unsqueeze both input and weight at dim 2 (the D dimension)
2. call the fp8 conv3d op from fbgemm: `torch.ops.fbgemm.f8f8bf16_conv`
3. assert that the D dimension of the result is 1 and squeeze it out at dim 2 (`res.squeeze(2)`) to remove the D dimension

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
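As a quick sanity check of the shape bookkeeping in the summary, the lifting can be exercised with the eager conv ops; `F.conv3d` stands in for `torch.ops.fbgemm.f8f8bf16_conv` here, since the fbgemm op's exact argument list isn't shown in this PR excerpt:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 32, 32)  # conv2d input  (N, C, H, W)
w = torch.randn(8, 16, 3, 3)    # conv2d weight (O, C, kH, kW)

# Step 1: unsqueeze both input and weight at dim 2 (the D dimension).
x3 = x.unsqueeze(2)             # (2, 16, 1, 32, 32)
w3 = w.unsqueeze(2)             # (8, 16, 1, 3, 3)

# Step 2: run the 3d conv (the PR calls the fbgemm fp8 kernel here instead).
out3 = F.conv3d(x3, w3, stride=[1, 1, 1], padding=[0, 1, 1], dilation=[1, 1, 1])

# Step 3: the D dimension stays 1, so squeeze it back out.
assert out3.shape[2] == 1
out = out3.squeeze(2)           # (2, 8, 32, 32)

# Matches a plain 2d conv on the original tensors.
ref = F.conv2d(x, w, stride=[1, 1], padding=[1, 1], dilation=[1, 1])
assert torch.allclose(out, ref, atol=1e-5)
```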
Force-pushed from 9cf37be to 22d7227