Conversation

@agolajko agolajko commented Nov 6, 2025

Summary

As discussed with @danielvegamyhre following #3290, this adds benchmark scripts for each quantization kernel in kernels.py.

Each script compares the kernel's performance against a reference torch implementation.

For triton_fp8_blockwise_act_quant_lhs I was able to reuse the torch implementation already in kernels.py; for the others I wrote new ones.
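Each bench script follows roughly the pattern sketched below (illustrative only; the import path, function signatures, and helper names are assumptions based on the PR description, not the exact code):

```python
# Minimal sketch of the torch-vs-triton comparison each bench script performs.
# The import path and signatures below are assumptions, not the exact PR code.
import torch
from triton.testing import do_bench

from torchao.prototype.blockwise_fp8_training.kernels import (  # assumed module path
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

def bench_lhs_act_quant(M: int, K: int, block_size: int = 128):
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

    # torch reference, compiled so the comparison is against fused torch code
    torch_impl = torch.compile(torch_blockwise_scale_act_quant_lhs)
    torch_us = do_bench(lambda: torch_impl(x, block_size)) * 1e3  # do_bench returns ms

    # hand-written triton kernel from kernels.py
    triton_us = do_bench(lambda: triton_fp8_blockwise_act_quant_lhs(x, block_size)) * 1e3

    print(f"{M}x{K}: torch {torch_us:.2f} us, triton {triton_us:.2f} us, "
          f"speedup {torch_us / triton_us:.2f}x")
```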

Bench results on a 5090:

Python: 3.13.8 (main, Oct 8 2025, 08:53:25) [GCC 13.3.0]
PyTorch: 2.9.0+cu128
CUDA: 12.8
cuDNN: 91002
Driver: 570.153.02, NVIDIA GeForce RTX 5090, compute capability 12.0
OS: Ubuntu 24.04.3 LTS

LHS Activation:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, K)   |   block_size |   torch_us |   triton_us | speedup   |   torch_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |       9.73 |        6.14 | 1.58x     |        653.5 |        1034.7 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      18.43 |       10.66 | 1.73x     |        689.8 |        1193.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |      26.59 |       22.53 | 1.18x     |        956.2 |        1128.7 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |      47.07 |       40.99 | 1.15x     |       1080.4 |        1240.6 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |      86.43 |       75.78 | 1.14x     |       1176.8 |        1342.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 16384x4096           |          128 |     180.83 |      145.38 | 1.24x     |       1124.9 |        1399.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 32768x4096           |          128 |     436.22 |      286.69 | 1.52x     |        932.7 |        1419.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 65536x4096           |          128 |     870.4  |      571.47 | 1.52x     |        934.9 |        1423.9 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 131072x4096          |          128 |    1738.75 |     1133.57 | 1.53x     |        936   |        1435.6 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
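For reference, the GB/s columns appear consistent with counting the bf16 input read plus the fp8 output and float32 per-block scale writes. A quick check against the 131072x4096 row, under that assumption:

```python
# Throughput sanity check for the 131072x4096 LHS-activation row, assuming the
# GB/s columns count bf16 input read + fp8 output write + fp32 scale write.
M, K, block = 131072, 4096, 128
bytes_moved = M * K * 2 + M * K * 1 + M * (K // block) * 4  # 1,627,389,952 bytes
print(bytes_moved / 1133.57e-6 / 1e9)  # ~1435.6 GB/s, matches triton_gbps
print(bytes_moved / 1738.75e-6 / 1e9)  # ~936.0 GB/s, matches torch_gbps
```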

RHS Activation:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, K)   |   block_size |   torch_us |   triton_us | speedup   |   torch_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      40.96 |        8.19 | 5.00x     |        155.2 |         776   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      71.68 |       10.24 | 7.00x     |        177.4 |        1241.6 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     133.12 |       16.38 | 8.13x     |        191   |        1552   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     259.07 |       30.72 | 8.43x     |        196.3 |        1655.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     516.67 |       61.44 | 8.41x     |        196.9 |        1655.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 16384x4096           |          128 |    1032.59 |      124.93 | 8.27x     |        197   |        1628.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 32768x4096           |          128 |    2057.22 |      256    | 8.04x     |        197.8 |        1589.2 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 65536x4096           |          128 |    4109.76 |      520.61 | 7.89x     |        198   |        1563   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 131072x4096          |          128 |    8208.8  |     1055.23 | 7.78x     |        198.2 |        1542.2 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

Transposed LHS Activation:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, K)   |   block_size |   torch_us |   triton_us | speedup   |   torch_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      40.96 |        8.19 | 5.00x     |        155.2 |         776   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      71.68 |       10.46 | 6.85x     |        177.4 |        1215   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     133.12 |       17.9  | 7.44x     |        191   |        1420.2 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     260.1  |       32.35 | 8.04x     |        195.5 |        1572   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     520.16 |       64.93 | 8.01x     |        195.5 |        1566.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 16384x4096           |          128 |    1037.54 |      126.98 | 8.17x     |        196.1 |        1602.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 32768x4096           |          128 |    2067.46 |      258.46 | 8.00x     |        196.8 |        1574.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 65536x4096           |          128 |    4126.72 |      522.24 | 7.90x     |        197.2 |        1558.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 131072x4096          |          128 |    8251.39 |     1053.7  | 7.83x     |        197.2 |        1544.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

RHS Weights:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, N)   |   block_size |   torch_us |   triton_us | speedup   |   torch_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      14.46 |        6.14 | 2.35x     |        435   |        1024.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      73.82 |        6.56 | 11.25x    |        170.5 |        1918.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     136.78 |       10.24 | 13.36x    |        184   |        2457.8 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     268.29 |       22.53 | 11.91x    |        187.6 |        2234.4 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     528.38 |       55.3  | 9.56x     |        190.5 |        1820.6 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 16384x4096           |          128 |    1045.09 |      118.78 | 8.80x     |        192.7 |        1695   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 32768x4096           |          128 |    2076.06 |      248.83 | 8.34x     |        194   |        1618.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 65536x4096           |          128 |    4138.53 |      508.93 | 8.13x     |        194.6 |        1582.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 131072x4096          |          128 |    8257.54 |     1026.85 | 8.04x     |        195.1 |        1568.6 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

Transposed RHS Weights:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, N)   |   block_size |   torch_us |   triton_us | speedup   |   torch_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      10.24 |        6.14 | 1.67x     |        614.5 |        1024.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      65.95 |        6.56 | 10.05x    |        190.8 |        1918.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     121.25 |       10.24 | 11.84x    |        207.6 |        2457.8 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     239.2  |       28.67 | 8.34x     |        210.4 |        1755.6 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     476.32 |       62.86 | 7.58x     |        211.4 |        1601.4 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 16384x4096           |          128 |     949.76 |      130.66 | 7.27x     |        212   |        1541   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 32768x4096           |          128 |    1897.47 |      272.19 | 6.97x     |        212.2 |        1479.4 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 65536x4096           |          128 |    3793.44 |      552.45 | 6.87x     |        212.3 |        1457.8 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 131072x4096          |          128 |    7582.26 |     1107.52 | 6.85x     |        212.4 |        1454.4 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

pytorch-bot bot commented Nov 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3306

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot commented Nov 6, 2025

Hi @agolajko!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

meta-cla bot commented Nov 6, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

meta-cla bot added the `CLA Signed` label Nov 6, 2025
@danielvegamyhre (Contributor)

Thanks for working on this @agolajko, can you run the benchmarks and include them in the PR description? or do you not have access to a b200 gpu?

@agolajko (Author) commented Nov 6, 2025

Yeah, I've been testing on a 5090 via Runpod, let me see if I can access their B200 as well

@danielvegamyhre (Contributor)

> Yeah, I've been testing on a 5090 via Runpod, let me see if I can access their B200 as well

ok any blackwell gpu is fine

@agolajko (Author) commented Nov 7, 2025

@danielvegamyhre I've added the results as well as some of the potentially relevant system info

@vkuzo (Contributor) commented Nov 7, 2025

If this extends to GEMMs, I'd recommend also benchmarking F.scaled_mm. Note that you need a nightly with pytorch/pytorch#166752 for correct numerics.

@agolajko (Author) commented Nov 7, 2025

Thanks for the suggestion @vkuzo; for now I'm keeping this PR to benchmarking the quantizations.

@agolajko (Author) commented Nov 7, 2025

Btw @vkuzo, what other GEMM benchmarking scripts are needed for FP8 training, given there are already two (bench_1x128_128x128 and bench_1x128_128x1) here: https://github.com/pytorch/ao/tree/main/benchmarks/prototype/blockwise_fp8_training

@danielvegamyhre (Contributor)

> Btw @vkuzo, what other GEMM benchmarking scripts are needed for FP8 training, given there are already two (bench_1x128_128x128 and bench_1x128_128x1) here: https://github.com/pytorch/ao/tree/main/benchmarks/prototype/blockwise_fp8_training

These 2 types of fp8 blockwise GEMM scaling, used across the three training matmuls below, are all we need:

  • out = input @ weight.t() is 1x128 @ 128x128 scaling
  • dgrad = grad_out @ weight is 1x128 @ 128x128 scaling
  • wgrad = grad_out.t() @ input is 1x128 @ 128x1 scaling
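For context, a minimal shape sketch of those three matmuls (plain bf16 matmuls used purely to illustrate shapes and scaling granularity; the real code quantizes the operands and calls a scaled GEMM instead):

```python
# Shape/scaling sketch for the three blockwise-fp8 training matmuls
# (bf16 matmuls for illustration only; not the actual fp8 code path).
import torch

M, K, N = 8192, 4096, 4096
input = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)     # activations: 1x128 scaling
weight = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)    # weights: 128x128 scaling
grad_out = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)  # grads: 1x128 scaling

out = input @ weight.t()      # (M, N): 1x128 @ 128x128 scaling
dgrad = grad_out @ weight     # (M, K): 1x128 @ 128x128 scaling
wgrad = grad_out.t() @ input  # (N, K): 1x128 @ 128x1 scaling
```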

@agolajko (Author) commented Nov 7, 2025

@danielvegamyhre Thanks for all the comments above, I've fixed those and re-benched the kernels (updated tables in the PR summary).

The new bench times show that triton_fp8_blockwise_act_quant_lhs is actually running faster now; I expect that's because it's not being torch compiled.

@danielvegamyhre (Contributor) left a comment

Looking good, thanks for your work on this! Almost ready to land, just a few last comments.

(1024, 4096),
(2048, 4096),
(4096, 4096),
(8192, 4096),

@danielvegamyhre (Contributor):

can we make the leading total_M dims (seq_len * local_batch_size) bigger? e.g. range of 8192, 8192*2, 8192*4, 8192*8, 8192*16? this is more representative of what we'll see in real training runs.

same for act_quant_rhs benchmarks
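For example, the suggested sweep could be written as (illustrative; the config variable name is an assumption):

```python
# Larger total_M = seq_len * local_batch_size sweep suggested in the review
# (variable name is illustrative, not the actual benchmark config).
input_shapes = [(8192 * i, 4096) for i in (1, 2, 4, 8, 16)]
# -> [(8192, 4096), (16384, 4096), (32768, 4096), (65536, 4096), (131072, 4096)]
```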

@agolajko (Author):

Thanks, any downside to having all the above quantization benchmarks with these bigger values?

)

# Benchmark naive implementation
naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)

@danielvegamyhre (Contributor):

Super nit: I wouldn't call the torch-native implementation "naive" per se - when using torch.compile these can sometimes be quite fast / close to speed of light. However, when that is not the case, we hand-implement kernels in Triton or CUDA (like we've done here).

I think just replacing naive_impl => torch_impl or similar would be better.
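Applied to the diff context above, the rename would look like:

```python
# Benchmark torch-native implementation (compiled)
torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
```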

@agolajko (Author):

I've replaced all mentions of 'naive' in function and variable names with 'torch'.

@@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

@danielvegamyhre (Contributor):

Please add the .py suffix to this file; it looks like it got removed or forgotten somehow:

benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs

M, K = config.input_shape
block_size = config.block_size

def verify_outputs(

@danielvegamyhre (Contributor):

In these various verify_outputs functions, can we also validate that the memory layouts are the same? I.e., check that shapes and strides match.

This way we are 100% sure we are doing a 1:1 comparison (writing to different memory layouts can drastically affect performance).
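A minimal sketch of what such a check could look like (hypothetical helper; argument names and tolerances are assumptions, not the PR code):

```python
# Hypothetical verify_outputs that also validates memory layout, per the review
# suggestion; argument names and tolerances are assumptions.
import torch

def verify_outputs(torch_out, torch_scale, triton_out, triton_scale, atol=1e-2, rtol=1e-2):
    # 1:1 memory-layout comparison: shapes and strides must match exactly
    assert torch_out.shape == triton_out.shape, "quantized output shapes differ"
    assert torch_out.stride() == triton_out.stride(), "quantized output strides differ"
    assert torch_scale.shape == triton_scale.shape, "scale shapes differ"
    assert torch_scale.stride() == triton_scale.stride(), "scale strides differ"

    # numerical comparison (fp8 tensors compared in float32)
    torch.testing.assert_close(torch_out.float(), triton_out.float(), atol=atol, rtol=rtol)
    torch.testing.assert_close(torch_scale, triton_scale, atol=atol, rtol=rtol)
```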

@danielvegamyhre added the `topic: not user facing` and `training` labels Nov 8, 2025
@agolajko (Author)

@danielvegamyhre I've updated the code with your comments, let me know what you think.

Also super appreciate your help with improving the quality of this code so it's closer to the rest of the codebase!
