Optimize NVFP4 Triton kernel #533
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: mxin <mxin@nvidia.com>
Force-pushed from f58e420 to 0cf5fb6
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #533   +/-   ##
=======================================
  Coverage   74.57%   74.57%
=======================================
  Files         183      183
  Lines       18412    18412
=======================================
  Hits        13730    13730
  Misses       4682     4682
=======================================
```
Thanks @mxinO. Do you have a unit test covering this change?
realAsma
left a comment
Nice!!
```python
# Guard against a non-positive global scale.
global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)

# Load input data
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
```
Just curious, it looks like the old version has the proper mask in tl.load and tl.store. Why does it cause the nvbug?
Yeah, the illegal memory access is hard to debug because the error message never points to the actual location. I didn't find the root cause; my guess is that it was an addressing issue, so I changed the way the data is loaded and that fixed it. The bug is a rare case that had never been seen before.
We have tests covering the Triton kernel's correctness.
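For reference, here is the old masked-load pattern in a minimal, self-contained form (a plain 1-D copy; the kernel and variable names are illustrative placeholders, not the actual kernel code):

```python
import triton
import triton.language as tl


@triton.jit
def _masked_copy(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Offsets are computed by hand, so the mask must exactly match the
    # addressing on every load and store; any mismatch can touch memory
    # out of bounds.
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x, mask=mask)
```

Even when the mask looks right, hand-computed offsets leave room for subtle addressing bugs; `tl.make_block_ptr` sidesteps this class of error because the tensor shape travels with the pointer and bounds are checked by Triton itself.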
## What does this PR do?
**Type of change:** Bug fix <!-- Use one of the following: Bug fix, new
feature, new example, new tests, documentation. -->
**Overview:**
1. Use `make_block_ptr` for loading blocks. This is safer and fixes an illegal memory access in rare cases (see the sketch after this list).
2. The tile rows and columns can now be specified separately.
3. Move the data type cast into the kernel to save memory for bf16/fp16 inputs.
4. Benchmarked against the old kernel on H100 and B200: significant speedups for medium and large inputs (B200: 1.4x -
2x, H100: 1.7x - 2.8x).
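A minimal sketch of the `tl.make_block_ptr` pattern referenced in item 1, written as a plain 2-D copy; the kernel and parameter names are illustrative placeholders, not the PR's actual kernel:

```python
import triton
import triton.language as tl


@triton.jit
def _block_ptr_copy(
    x_ptr, out_ptr, M, N,
    stride_xm, stride_xn, stride_om, stride_on,
    TILE_M: tl.constexpr, TILE_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # The block pointer carries the full tensor shape, so out-of-bounds
    # tiles are clipped with boundary_check instead of hand-built masks.
    x_block = tl.make_block_ptr(
        base=x_ptr, shape=(M, N), strides=(stride_xm, stride_xn),
        offsets=(pid_m * TILE_M, pid_n * TILE_N),
        block_shape=(TILE_M, TILE_N), order=(1, 0),
    )
    out_block = tl.make_block_ptr(
        base=out_ptr, shape=(M, N), strides=(stride_om, stride_on),
        offsets=(pid_m * TILE_M, pid_n * TILE_N),
        block_shape=(TILE_M, TILE_N), order=(1, 0),
    )
    # Cast inside the kernel (item 3), so bf16/fp16 inputs do not need
    # an fp32 copy on the host side first.
    x = tl.load(x_block, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    tl.store(out_block, x, boundary_check=(0, 1))
```

The separate `TILE_M`/`TILE_N` constants correspond to item 2: tile rows and columns can be tuned independently.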
H100:
```shell
Shape: 512x512
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.32 µs
new kernel: 38.49 µs
speedup: 0.92x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.48 µs
new kernel: 44.78 µs
speedup: 0.97x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.25 µs
new kernel: 43.69 µs
speedup: 0.99x
Shape: 1024x1024
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 36.03 µs
new kernel: 38.17 µs
speedup: 0.94x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 44.24 µs
new kernel: 43.78 µs
speedup: 1.01x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.77 µs
new kernel: 43.61 µs
speedup: 1.00x
Shape: 4096x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 87.02 µs
new kernel: 80.88 µs
speedup: 1.08x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 116.12 µs
new kernel: 65.80 µs
speedup: 1.76x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 114.39 µs
new kernel: 65.30 µs
speedup: 1.75x
Shape: 8192x8192
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 237.29 µs
new kernel: 219.42 µs
speedup: 1.08x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 349.76 µs
new kernel: 138.66 µs
speedup: 2.52x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 341.89 µs
new kernel: 136.91 µs
speedup: 2.50x
Shape: 8192x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 338.65 µs
new kernel: 312.70 µs
speedup: 1.08x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 505.63 µs
new kernel: 188.24 µs
speedup: 2.69x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 492.97 µs
new kernel: 186.88 µs
speedup: 2.64x
Shape: 12288x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 490.25 µs
new kernel: 451.16 µs
speedup: 1.09x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 736.04 µs
new kernel: 261.94 µs
speedup: 2.81x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 717.64 µs
new kernel: 257.82 µs
speedup: 2.78x
Shape: 32x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.61 µs
new kernel: 38.23 µs
speedup: 0.93x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.00 µs
new kernel: 43.85 µs
speedup: 0.98x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 42.83 µs
new kernel: 44.13 µs
speedup: 0.97x
Shape: 1024x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 38.12 µs
new kernel: 41.28 µs
speedup: 0.92x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 52.80 µs
new kernel: 45.96 µs
speedup: 1.15x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 51.56 µs
new kernel: 45.30 µs
speedup: 1.14x
Shape: 32x5000
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 41.70 µs
new kernel: 38.03 µs
speedup: 1.10x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 52.95 µs
new kernel: 44.14 µs
speedup: 1.20x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 52.57 µs
new kernel: 44.38 µs
speedup: 1.18x
Shape: 128x8200
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 48.03 µs
new kernel: 38.38 µs
speedup: 1.25x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 60.54 µs
new kernel: 44.51 µs
speedup: 1.36x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 60.08 µs
new kernel: 43.59 µs
speedup: 1.38x
```
B200:
```shell
Shape: 512x512
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 34.63 µs
new kernel: 32.80 µs
speedup: 1.06x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 42.26 µs
new kernel: 40.92 µs
speedup: 1.03x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 41.38 µs
new kernel: 39.30 µs
speedup: 1.05x
Shape: 1024x1024
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.07 µs
new kernel: 33.93 µs
speedup: 1.03x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.57 µs
new kernel: 39.55 µs
speedup: 1.10x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.72 µs
new kernel: 38.96 µs
speedup: 1.12x
Shape: 4096x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 71.64 µs
new kernel: 58.66 µs
speedup: 1.22x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 81.67 µs
new kernel: 57.98 µs
speedup: 1.41x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 82.19 µs
new kernel: 57.56 µs
speedup: 1.43x
Shape: 8192x8192
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 176.85 µs
new kernel: 135.78 µs
speedup: 1.30x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 217.99 µs
new kernel: 121.84 µs
speedup: 1.79x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 215.47 µs
new kernel: 117.41 µs
speedup: 1.84x
Shape: 8192x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 248.18 µs
new kernel: 186.64 µs
speedup: 1.33x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 306.25 µs
new kernel: 163.28 µs
speedup: 1.88x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 303.06 µs
new kernel: 157.59 µs
speedup: 1.92x
Shape: 12288x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 354.23 µs
new kernel: 262.99 µs
speedup: 1.35x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 439.44 µs
new kernel: 224.71 µs
speedup: 1.96x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 434.23 µs
new kernel: 217.62 µs
speedup: 2.00x
Shape: 32x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.90 µs
new kernel: 34.88 µs
speedup: 1.03x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.77 µs
new kernel: 41.49 µs
speedup: 1.05x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.22 µs
new kernel: 41.79 µs
speedup: 1.03x
Shape: 1024x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 37.37 µs
new kernel: 37.84 µs
speedup: 0.99x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 49.69 µs
new kernel: 43.85 µs
speedup: 1.13x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 48.93 µs
new kernel: 44.31 µs
speedup: 1.10x
Shape: 32x5000
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 41.83 µs
new kernel: 35.44 µs
speedup: 1.18x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 53.23 µs
new kernel: 40.64 µs
speedup: 1.31x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 54.39 µs
new kernel: 40.77 µs
speedup: 1.33x
Shape: 128x8200
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 49.35 µs
new kernel: 35.33 µs
speedup: 1.40x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 60.89 µs
new kernel: 41.46 µs
speedup: 1.47x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 61.75 µs
new kernel: 41.75 µs
speedup: 1.48x
```
## Testing
<!-- Mention how have you tested your change if applicable. -->
1. Compared against the old kernel: max abs diff = 0 across all tested shapes and dtypes.
2. Benchmarked speed (full results above; a sketch of this kind of harness follows).
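A sketch of the kind of harness that produces numbers like the ones above, assuming hypothetical `quantize_nvfp4_old` / `quantize_nvfp4_new` wrappers around the two kernels (not the repo's actual function names):

```python
import torch
import triton


def compare(shape, dtype):
    x = torch.randn(shape, device="cuda", dtype=dtype)
    # Correctness: the two kernels should agree exactly.
    ref = quantize_nvfp4_old(x)
    out = quantize_nvfp4_new(x)
    print(f"max abs diff: {(out.float() - ref.float()).abs().max().item():.3e}")
    # Speed: do_bench returns the measured runtime in milliseconds.
    t_old = triton.testing.do_bench(lambda: quantize_nvfp4_old(x))
    t_new = triton.testing.do_bench(lambda: quantize_nvfp4_new(x))
    print(f"old kernel: {t_old * 1e3:.2f} µs")
    print(f"new kernel: {t_new * 1e3:.2f} µs")
    print(f"speedup: {t_old / t_new:.2f}x")


for shape in [(512, 512), (4096, 4096), (32, 5000), (128, 8200)]:
    for dtype in [torch.float32, torch.bfloat16, torch.float16]:
        compare(shape, dtype)
```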
## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes <!--- If No, explain why.
-->
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update
[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:
No <!--- Only for new features, API changes, critical bug fixes or bw
breaking changes. -->
## Additional Information
Bug [5612406]