Skip to content

Conversation

@mxinO
Copy link
Contributor

@mxinO mxinO commented Nov 11, 2025

What does this PR do?

Type of change: Bug fix

Overview:

  1. Use mak_block_ptr for loading blocks, now it's more safe, fix illegal memory access in rare cases.
  2. Now the tile rows and columns can be specified separately.
  3. Moving data type cast to kernel to save memory for bf16/fp16 inputs.
  4. I did a benchmark comparing with the old kernel on H100 and B200, it has significant speed-up for medium and large size inputs (B200: 1.4x - 2x, H100: 1.7x - 2.8x)

H100:

Shape: 512x512                                                                                                                                                                               
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 35.32 µs                                                                                                                                                                     
    new kernel: 38.49 µs                                                                                                                                                                     
    speedup: 0.92x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.48 µs                                                                                                                                                                     
    new kernel: 44.78 µs                                                                                                                                                                     
    speedup: 0.97x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.25 µs                                                                                                                                                                     
    new kernel: 43.69 µs                                                                                                                                                                     
    speedup: 0.99x                                                                                                                                                                           
                                                                                                                                                                                             
Shape: 1024x1024                                                                                                                                                                             
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 36.03 µs                                                                                                                                                                     
    new kernel: 38.17 µs                                                                                                                                                                     
    speedup: 0.94x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 44.24 µs                                                                                                                                                                     
    new kernel: 43.78 µs                                                                                                                                                                     
    speedup: 1.01x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.77 µs                                                                                                                                                                     
    new kernel: 43.61 µs                                                                                                                                                                     
    speedup: 1.00x                                                                                                                                                                           

Shape: 4096x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 87.02 µs
    new kernel: 80.88 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 116.12 µs
    new kernel: 65.80 µs
    speedup: 1.76x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 114.39 µs
    new kernel: 65.30 µs
    speedup: 1.75x

Shape: 8192x8192
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 237.29 µs
    new kernel: 219.42 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 349.76 µs
    new kernel: 138.66 µs
    speedup: 2.52x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 341.89 µs
    new kernel: 136.91 µs
    speedup: 2.50x

Shape: 8192x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 338.65 µs
    new kernel: 312.70 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 505.63 µs
    new kernel: 188.24 µs
    speedup: 2.69x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 492.97 µs
    new kernel: 186.88 µs
    speedup: 2.64x

Shape: 12288x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 490.25 µs
    new kernel: 451.16 µs
    speedup: 1.09x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 736.04 µs
    new kernel: 261.94 µs
    speedup: 2.81x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 717.64 µs
    new kernel: 257.82 µs
    speedup: 2.78x

Shape: 32x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 35.61 µs
    new kernel: 38.23 µs
    speedup: 0.93x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.00 µs
    new kernel: 43.85 µs
    speedup: 0.98x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 42.83 µs
    new kernel: 44.13 µs
    speedup: 0.97x

Shape: 1024x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 38.12 µs
    new kernel: 41.28 µs
    speedup: 0.92x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.80 µs
    new kernel: 45.96 µs
    speedup: 1.15x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 51.56 µs
    new kernel: 45.30 µs
    speedup: 1.14x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.70 µs
    new kernel: 38.03 µs
    speedup: 1.10x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.95 µs
    new kernel: 44.14 µs
    speedup: 1.20x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 52.57 µs
    new kernel: 44.38 µs
    speedup: 1.18x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.70 µs
    new kernel: 38.03 µs
    speedup: 1.10x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.95 µs
    new kernel: 44.14 µs
    speedup: 1.20x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 52.57 µs
    new kernel: 44.38 µs
    speedup: 1.18x

Shape: 128x8200
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 48.03 µs
    new kernel: 38.38 µs
    speedup: 1.25x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 60.54 µs
    new kernel: 44.51 µs
    speedup: 1.36x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 60.08 µs
    new kernel: 43.59 µs
    speedup: 1.38x

B200:

Shape: 512x512                                                                                                                                                                               
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 34.63 µs                                                                                                                                                                     
    new kernel: 32.80 µs                                                                                                                                                                     
    speedup: 1.06x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 42.26 µs                                                                                                                                                                     
    new kernel: 40.92 µs                                                                                                                                                                     
    speedup: 1.03x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 41.38 µs                                                                                                                                                                     
    new kernel: 39.30 µs                                                                                                                                                                     
    speedup: 1.05x                                                                                                                                                                           
                                                                                                                                                                                             
Shape: 1024x1024                                                                                                                                                                             
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 35.07 µs                                                                                                                                                                     
    new kernel: 33.93 µs                                                                                                                                                                     
    speedup: 1.03x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.57 µs                                                                                                                                                                     
    new kernel: 39.55 µs                                                                                                                                                                     
    speedup: 1.10x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.72 µs
    new kernel: 38.96 µs
    speedup: 1.12x

Shape: 4096x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 71.64 µs
    new kernel: 58.66 µs
    speedup: 1.22x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 81.67 µs
    new kernel: 57.98 µs
    speedup: 1.41x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 82.19 µs
    new kernel: 57.56 µs
    speedup: 1.43x

Shape: 8192x8192
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 176.85 µs
    new kernel: 135.78 µs
    speedup: 1.30x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 217.99 µs
    new kernel: 121.84 µs
    speedup: 1.79x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 215.47 µs
    new kernel: 117.41 µs
    speedup: 1.84x

Shape: 8192x12288                                                                                                                                                     
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 248.18 µs
    new kernel: 186.64 µs
    speedup: 1.33x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 306.25 µs
    new kernel: 163.28 µs
    speedup: 1.88x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 303.06 µs
    new kernel: 157.59 µs
    speedup: 1.92x

Shape: 12288x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 354.23 µs
    new kernel: 262.99 µs
    speedup: 1.35x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 439.44 µs
    new kernel: 224.71 µs
    speedup: 1.96x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 434.23 µs
    new kernel: 217.62 µs
    speedup: 2.00x

Shape: 32x4096                                                                                                                                                                               
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 35.90 µs                                                                                                                                                                     
    new kernel: 34.88 µs
    speedup: 1.03x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.77 µs
    new kernel: 41.49 µs
    speedup: 1.05x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 43.22 µs
    new kernel: 41.79 µs
    speedup: 1.03x

Shape: 1024x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 37.37 µs
    new kernel: 37.84 µs
    speedup: 0.99x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 49.69 µs
    new kernel: 43.85 µs
    speedup: 1.13x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 48.93 µs
    new kernel: 44.31 µs
    speedup: 1.10x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.83 µs
    new kernel: 35.44 µs
    speedup: 1.18x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 53.23 µs
    new kernel: 40.64 µs
    speedup: 1.31x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 54.39 µs
    new kernel: 40.77 µs
    speedup: 1.33x

Shape: 128x8200
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 49.35 µs
    new kernel: 35.33 µs
    speedup: 1.40x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 60.89 µs
    new kernel: 41.46 µs
    speedup: 1.47x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 61.75 µs
    new kernel: 41.75 µs
    speedup: 1.48x

Testing

  1. Compared with old kernel, diff=0
  2. Benchmark speed

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Bug [5612406]

@copy-pr-bot
Copy link

copy-pr-bot bot commented Nov 11, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Signed-off-by: mxin <mxin@nvidia.com>
@mxinO mxinO force-pushed the mxin/fp4-kernel-improve branch from f58e420 to 0cf5fb6 Compare November 11, 2025 06:25
@mxinO mxinO self-assigned this Nov 11, 2025
@codecov
Copy link

codecov bot commented Nov 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.57%. Comparing base (7a36ccc) to head (1afadfc).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #533   +/-   ##
=======================================
  Coverage   74.57%   74.57%           
=======================================
  Files         183      183           
  Lines       18412    18412           
=======================================
  Hits        13730    13730           
  Misses       4682     4682           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Signed-off-by: mxin <mxin@nvidia.com>
@mxinO mxinO changed the title Improve NVFP4 Triton kernel Optimize NVFP4 Triton kernel Nov 13, 2025
@mxinO mxinO requested review from RalphMao and realAsma November 18, 2025 01:33
@mxinO mxinO marked this pull request as ready for review November 18, 2025 01:34
@mxinO mxinO requested a review from a team as a code owner November 18, 2025 01:34
@mxinO mxinO requested a review from cjluo-nv November 18, 2025 01:34
@cjluo-nv
Copy link
Collaborator

Thanks @mxinO. Do you have unittest cover this change?

Copy link
Contributor

@realAsma realAsma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!!

global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)

# Load input data
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, it looks like the old version has the proper mask in tl.load and tl.store. Why does it cause the nvbug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the illegal memory is hard to debug, because the error message never directs to the correct position. I didn't find the root cause actually, just guess it was the addressing issue. So changed the way to load and it's fixed. That bug is a rare case, It's never seen before.

@mxinO
Copy link
Contributor Author

mxinO commented Nov 19, 2025

Thanks @mxinO. Do you have unittest cover this change?

We have tests covering the triton kernel's correctness.

@mxinO mxinO enabled auto-merge (squash) November 25, 2025 01:26
@mxinO mxinO merged commit f06c3f9 into main Nov 25, 2025
27 checks passed
@mxinO mxinO deleted the mxin/fp4-kernel-improve branch November 25, 2025 03:12
jQizhang pushed a commit to jQizhang/TensorRT-Model-Optimizer that referenced this pull request Nov 26, 2025
## What does this PR do?

**Type of change:** Bug fix <!-- Use one of the following: Bug fix, new
feature, new example, new tests, documentation. -->

**Overview:** 

1. Use mak_block_ptr for loading blocks, now it's more safe, fix illegal
memory access in rare cases.
2. Now the tile rows and columns can be specified separately.
3. Moving data type cast to kernel to save memory for bf16/fp16 inputs.
4. I did a benchmark comparing with the old kernel on H100 and B200, it
has significant speed-up for medium and large size inputs (B200: 1.4x -
2x, H100: 1.7x - 2.8x)

H100:
```shell
Shape: 512x512                                                                                                                                                                               
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 35.32 µs                                                                                                                                                                     
    new kernel: 38.49 µs                                                                                                                                                                     
    speedup: 0.92x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.48 µs                                                                                                                                                                     
    new kernel: 44.78 µs                                                                                                                                                                     
    speedup: 0.97x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.25 µs                                                                                                                                                                     
    new kernel: 43.69 µs                                                                                                                                                                     
    speedup: 0.99x                                                                                                                                                                           
                                                                                                                                                                                             
Shape: 1024x1024                                                                                                                                                                             
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 36.03 µs                                                                                                                                                                     
    new kernel: 38.17 µs                                                                                                                                                                     
    speedup: 0.94x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 44.24 µs                                                                                                                                                                     
    new kernel: 43.78 µs                                                                                                                                                                     
    speedup: 1.01x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.77 µs                                                                                                                                                                     
    new kernel: 43.61 µs                                                                                                                                                                     
    speedup: 1.00x                                                                                                                                                                           

Shape: 4096x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 87.02 µs
    new kernel: 80.88 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 116.12 µs
    new kernel: 65.80 µs
    speedup: 1.76x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 114.39 µs
    new kernel: 65.30 µs
    speedup: 1.75x

Shape: 8192x8192
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 237.29 µs
    new kernel: 219.42 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 349.76 µs
    new kernel: 138.66 µs
    speedup: 2.52x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 341.89 µs
    new kernel: 136.91 µs
    speedup: 2.50x

Shape: 8192x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 338.65 µs
    new kernel: 312.70 µs
    speedup: 1.08x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 505.63 µs
    new kernel: 188.24 µs
    speedup: 2.69x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 492.97 µs
    new kernel: 186.88 µs
    speedup: 2.64x

Shape: 12288x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 490.25 µs
    new kernel: 451.16 µs
    speedup: 1.09x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 736.04 µs
    new kernel: 261.94 µs
    speedup: 2.81x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 717.64 µs
    new kernel: 257.82 µs
    speedup: 2.78x

Shape: 32x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 35.61 µs
    new kernel: 38.23 µs
    speedup: 0.93x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.00 µs
    new kernel: 43.85 µs
    speedup: 0.98x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 42.83 µs
    new kernel: 44.13 µs
    speedup: 0.97x

Shape: 1024x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 38.12 µs
    new kernel: 41.28 µs
    speedup: 0.92x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.80 µs
    new kernel: 45.96 µs
    speedup: 1.15x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 51.56 µs
    new kernel: 45.30 µs
    speedup: 1.14x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.70 µs
    new kernel: 38.03 µs
    speedup: 1.10x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.95 µs
    new kernel: 44.14 µs
    speedup: 1.20x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 52.57 µs
    new kernel: 44.38 µs
    speedup: 1.18x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.70 µs
    new kernel: 38.03 µs
    speedup: 1.10x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 52.95 µs
    new kernel: 44.14 µs
    speedup: 1.20x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 52.57 µs
    new kernel: 44.38 µs
    speedup: 1.18x

Shape: 128x8200
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 48.03 µs
    new kernel: 38.38 µs
    speedup: 1.25x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 60.54 µs
    new kernel: 44.51 µs
    speedup: 1.36x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 60.08 µs
    new kernel: 43.59 µs
    speedup: 1.38x
```

B200:
```shell
Shape: 512x512                                                                                                                                                                               
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 34.63 µs                                                                                                                                                                     
    new kernel: 32.80 µs                                                                                                                                                                     
    speedup: 1.06x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 42.26 µs                                                                                                                                                                     
    new kernel: 40.92 µs                                                                                                                                                                     
    speedup: 1.03x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 41.38 µs                                                                                                                                                                     
    new kernel: 39.30 µs                                                                                                                                                                     
    speedup: 1.05x                                                                                                                                                                           
                                                                                                                                                                                             
Shape: 1024x1024                                                                                                                                                                             
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 35.07 µs                                                                                                                                                                     
    new kernel: 33.93 µs                                                                                                                                                                     
    speedup: 1.03x                                                                                                                                                                           
  dtype: torch.bfloat16                                                                                                                                                                      
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.57 µs                                                                                                                                                                     
    new kernel: 39.55 µs                                                                                                                                                                     
    speedup: 1.10x                                                                                                                                                                           
  dtype: torch.float16                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 43.72 µs
    new kernel: 38.96 µs
    speedup: 1.12x

Shape: 4096x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 71.64 µs
    new kernel: 58.66 µs
    speedup: 1.22x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 81.67 µs
    new kernel: 57.98 µs
    speedup: 1.41x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 82.19 µs
    new kernel: 57.56 µs
    speedup: 1.43x

Shape: 8192x8192
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 176.85 µs
    new kernel: 135.78 µs
    speedup: 1.30x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 217.99 µs
    new kernel: 121.84 µs
    speedup: 1.79x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 215.47 µs
    new kernel: 117.41 µs
    speedup: 1.84x

Shape: 8192x12288                                                                                                                                                     
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 248.18 µs
    new kernel: 186.64 µs
    speedup: 1.33x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 306.25 µs
    new kernel: 163.28 µs
    speedup: 1.88x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 303.06 µs
    new kernel: 157.59 µs
    speedup: 1.92x

Shape: 12288x12288
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 354.23 µs
    new kernel: 262.99 µs
    speedup: 1.35x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 439.44 µs
    new kernel: 224.71 µs
    speedup: 1.96x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 434.23 µs
    new kernel: 217.62 µs
    speedup: 2.00x

Shape: 32x4096                                                                                                                                                                               
  dtype: torch.float32                                                                                                                                                                       
    max abs diff: 0.000e+00                                                                                                                                                                  
    old kernel: 35.90 µs                                                                                                                                                                     
    new kernel: 34.88 µs
    speedup: 1.03x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 43.77 µs
    new kernel: 41.49 µs
    speedup: 1.05x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 43.22 µs
    new kernel: 41.79 µs
    speedup: 1.03x

Shape: 1024x4096
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 37.37 µs
    new kernel: 37.84 µs
    speedup: 0.99x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 49.69 µs
    new kernel: 43.85 µs
    speedup: 1.13x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 48.93 µs
    new kernel: 44.31 µs
    speedup: 1.10x

Shape: 32x5000
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 41.83 µs
    new kernel: 35.44 µs
    speedup: 1.18x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 53.23 µs
    new kernel: 40.64 µs
    speedup: 1.31x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 54.39 µs
    new kernel: 40.77 µs
    speedup: 1.33x

Shape: 128x8200
  dtype: torch.float32
    max abs diff: 0.000e+00
    old kernel: 49.35 µs
    new kernel: 35.33 µs
    speedup: 1.40x
  dtype: torch.bfloat16
    max abs diff: 0.000e+00
    old kernel: 60.89 µs
    new kernel: 41.46 µs
    speedup: 1.47x
  dtype: torch.float16
    max abs diff: 0.000e+00
    old kernel: 61.75 µs
    new kernel: 41.75 µs
    speedup: 1.48x
```

## Testing
<!-- Mention how have you tested your change if applicable. -->
1. Compared with old kernel, diff=0
2. Benchmark speed
## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes <!--- If No, explain why.
-->
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update
[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:
No <!--- Only for new features, API changes, critical bug fixes or bw
breaking changes. -->

## Additional Information
Bug [5612406]

---------

Signed-off-by: mxin <mxin@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants