Re: #3290 FP8 Blockwise Training Tracker, quantization benchmarks #3306
base: main
Conversation
Commits:
- …aive torch implementation
- …entation torch_blockwise_scale_act_quant_lhs from existing blockwise_fp8_training/kernels
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3306
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @agolajko! Thank you for your pull request and welcome to our community.

Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for working on this @agolajko, can you run the benchmarks and include them in the PR description? Or do you not have access to a B200 GPU?
Yeah, I've been testing on a 5090 via Runpod, let me see if I can access their B200 as well.
OK, any Blackwell GPU is fine.
@danielvegamyhre added the results as well as some of the potentially relevant sys info.
If this extends to GEMMs, I'd recommend also benchmarking
Thanks for the suggestion @vkuzo, for now I'm just keeping this PR to benchmarking the quantizations.
Btw @vkuzo, what are the other GEMM benchmarking scripts needed for FP8 training, given there are already two (bench_1x128_128x128 and bench_1x128_128x1) here: https://github.com/pytorch/ao/tree/main/benchmarks/prototype/blockwise_fp8_training
These 2 types of FP8 blockwise GEMMs are all we need for training:
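For context, the two blockwise GEMM granularities named above (1x128-scaled activations against 128x128-scaled weights in the forward, and 1x128 against 128x1 in the grad matmuls) can be illustrated with a plain-PyTorch dequantize-then-matmul sketch. The `dequant_blockwise` helper below is hypothetical and only shows the reference semantics, not the triton kernels in the repo:

```python
import torch

def dequant_blockwise(x_fp8, scales, block_shape):
    # Hypothetical helper: expand one scale per block back to element
    # granularity, then dequantize to float32 for a reference matmul.
    bm, bk = block_shape
    s = scales.repeat_interleave(bm, dim=0).repeat_interleave(bk, dim=1)
    return x_fp8.to(torch.float32) * s

# GEMM 1: (1x128)-scaled activations @ (128x128)-scaled weights (forward).
a = torch.randn(8192, 4096, device="cuda").to(torch.float8_e4m3fn)
a_scales = torch.rand(8192, 4096 // 128, device="cuda")  # one scale per 1x128 block
w = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
w_scales = torch.rand(4096 // 128, 4096 // 128, device="cuda")  # one per 128x128 block
out = dequant_blockwise(a, a_scales, (1, 128)) @ dequant_blockwise(w, w_scales, (128, 128)).t()

# GEMM 2: (1x128)-scaled @ (128x1)-scaled operands, as in the grad matmuls,
# uses the same helper with block_shape=(128, 1) for the second operand.
```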
@danielvegamyhre Thanks for all the comments from above, I have fixed those and re-benched the kernels (updated tables in the PR summary). The new bench times show that
danielvegamyhre left a comment:
Looking good, thanks for your work on this! Almost ready to land, just a few last comments
(1024, 4096),
(2048, 4096),
(4096, 4096),
(8192, 4096),
Can we make the leading total_M dims (seq_len * local_batch_size) bigger? E.g. a range of 8192, 8192*2, 8192*4, 8192*8, 8192*16? This is more representative of what we'll see in real training runs.
Same for the act_quant_rhs benchmarks.
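A minimal sketch of that enlarged sweep (the `input_shapes` name is an assumption based on the config snippet above):

```python
# Leading total_M dim = seq_len * local_batch_size, grown toward
# sizes seen in real training runs (per the review suggestion).
input_shapes = [(8192 * 2**i, 4096) for i in range(5)]
# -> (8192, 4096), (16384, 4096), (32768, 4096), (65536, 4096), (131072, 4096)
```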
Thanks, any downside to having all the above quantization benchmarks with these bigger values?
)

# Benchmark naive implementation
naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
Super nit: I wouldn't call the torch-native implementation "naive" per se - when using torch.compile these can sometimes be quite fast / close to speed of light. However, when that is not the case, we hand-implement kernels with triton or cuda (like we've done here).
I think just replacing naive_impl => torch_impl or similar would be better.
I've replaced all mentions of naive in function and variable names with 'torch'.
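After the rename, the baseline setup might look roughly like the sketch below; the import path and kernel signatures are assumptions based on the file and function names in this PR, and the actual benchmark harness may differ:

```python
import torch
from triton.testing import do_bench

# Assumed import path, based on the kernels module referenced in this PR:
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

# Compile the torch-native reference once, outside the timed region,
# so compile time isn't measured.
torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# do_bench handles warmup/repeats and returns a median runtime in ms.
torch_ms = do_bench(lambda: torch_impl_c(x))
triton_ms = do_bench(lambda: triton_fp8_blockwise_act_quant_lhs(x))
print(f"torch.compile: {torch_ms:.3f} ms | triton: {triton_ms:.3f} ms")
```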
@@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Please add .py suffix to this file, looks like it got removed or forgotten somehow:
benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs
M, K = config.input_shape
block_size = config.block_size

def verify_outputs(
In these various verify_outputs can we also validate that the memory layouts are the same? I.e., check that shapes and strides match.
This way we are 100% sure we are doing a 1:1 comparison (writing to different memory layouts can drastically affect performance).
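A small sketch of such a layout check (the tensor names passed in are hypothetical):

```python
import torch

def assert_same_layout(ref: torch.Tensor, test: torch.Tensor, name: str) -> None:
    # Matching values alone aren't enough for a fair benchmark: different
    # strides mean different memory layouts, which can drastically change
    # kernel performance.
    assert ref.shape == test.shape, f"{name}: shape {ref.shape} vs {test.shape}"
    assert ref.stride() == test.stride(), f"{name}: stride {ref.stride()} vs {test.stride()}"

# e.g. inside verify_outputs:
#   assert_same_layout(torch_out, triton_out, "quantized output")
#   assert_same_layout(torch_scale, triton_scale, "scales")
```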
@danielvegamyhre updated the code with your comments, lmk what you think. Also super appreciate your help with improving the quality of this code so it's closer to the rest of the codebase!
Summary
As discussed with @danielvegamyhre following #3290, this creates benchmark scripts for each quantization kernel from kernels.py.
This compares the kernels' performance against a naive torch implementation.
For triton_fp8_blockwise_act_quant_lhs I was able to use the naive torch implementation in the kernels.py file; for the others I wrote new ones.
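For context, a torch-native 1x128 blockwise activation quant can be sketched as below. This is only an illustration of the approach (with an assumed float8_e4m3 max of 448 and K divisible by the block size), not the exact code from kernels.py:

```python
import torch

def torch_blockwise_quant_1x128(x: torch.Tensor, block_size: int = 128):
    # Hypothetical reference: one scale per contiguous 1x128 block along K,
    # chosen so each block's absmax maps to the float8_e4m3 max (448).
    M, K = x.shape  # K must be divisible by block_size
    x_blocks = x.float().reshape(M, K // block_size, block_size)
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / 448.0
    x_fp8 = (x_blocks / scale).to(torch.float8_e4m3fn).reshape(M, K)
    return x_fp8, scale.squeeze(-1)  # scales shape: (M, K // block_size)
```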
Bench results on a 5090
Python: 3.13.8 (main, Oct 8 2025, 08:53:25) [GCC 13.3.0]
PyTorch: 2.9.0+cu128
CUDA: 12.8
CuDNN: 91002
Driver: 570.153.02, NVIDIA GeForce RTX 5090, compute capability 12.0
Ubuntu 24.04.3 LTS
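The system info above can be gathered with standard PyTorch calls, e.g.:

```python
import sys
import torch

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"CuDNN: {torch.backends.cudnn.version()}")
print(torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0))
```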
LHS Activation:
RHS Activation:
Transposed LHS Activation:
RHS Weights:
Transposed RHS Weights: