Commit fd8db7b
Optimize NVFP4 Triton kernel (NVIDIA#533)
## What does this PR do?
**Type of change:** Bug fix
**Overview:**
1. Use `tl.make_block_ptr` for loading blocks; this is safer and fixes an illegal memory access that occurred in rare cases.
2. The tile rows and columns can now be specified separately.
3. Move the data type cast into the kernel to save memory for bf16/fp16 inputs.
4. Benchmarked against the old kernel on H100 and B200; there is a significant speed-up for medium and large inputs (B200: 1.4x - 2x, H100: 1.7x - 2.8x).
H100:
```shell
Shape: 512x512
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.32 µs
new kernel: 38.49 µs
speedup: 0.92x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.48 µs
new kernel: 44.78 µs
speedup: 0.97x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.25 µs
new kernel: 43.69 µs
speedup: 0.99x
Shape: 1024x1024
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 36.03 µs
new kernel: 38.17 µs
speedup: 0.94x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 44.24 µs
new kernel: 43.78 µs
speedup: 1.01x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.77 µs
new kernel: 43.61 µs
speedup: 1.00x
Shape: 4096x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 87.02 µs
new kernel: 80.88 µs
speedup: 1.08x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 116.12 µs
new kernel: 65.80 µs
speedup: 1.76x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 114.39 µs
new kernel: 65.30 µs
speedup: 1.75x
Shape: 8192x8192
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 237.29 µs
new kernel: 219.42 µs
speedup: 1.08x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 349.76 µs
new kernel: 138.66 µs
speedup: 2.52x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 341.89 µs
new kernel: 136.91 µs
speedup: 2.50x
Shape: 8192x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 338.65 µs
new kernel: 312.70 µs
speedup: 1.08x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 505.63 µs
new kernel: 188.24 µs
speedup: 2.69x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 492.97 µs
new kernel: 186.88 µs
speedup: 2.64x
Shape: 12288x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 490.25 µs
new kernel: 451.16 µs
speedup: 1.09x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 736.04 µs
new kernel: 261.94 µs
speedup: 2.81x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 717.64 µs
new kernel: 257.82 µs
speedup: 2.78x
Shape: 32x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.61 µs
new kernel: 38.23 µs
speedup: 0.93x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.00 µs
new kernel: 43.85 µs
speedup: 0.98x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 42.83 µs
new kernel: 44.13 µs
speedup: 0.97x
Shape: 1024x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 38.12 µs
new kernel: 41.28 µs
speedup: 0.92x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 52.80 µs
new kernel: 45.96 µs
speedup: 1.15x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 51.56 µs
new kernel: 45.30 µs
speedup: 1.14x
Shape: 32x5000
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 41.70 µs
new kernel: 38.03 µs
speedup: 1.10x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 52.95 µs
new kernel: 44.14 µs
speedup: 1.20x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 52.57 µs
new kernel: 44.38 µs
speedup: 1.18x
Shape: 128x8200
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 48.03 µs
new kernel: 38.38 µs
speedup: 1.25x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 60.54 µs
new kernel: 44.51 µs
speedup: 1.36x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 60.08 µs
new kernel: 43.59 µs
speedup: 1.38x
```
B200:
```shell
Shape: 512x512
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 34.63 µs
new kernel: 32.80 µs
speedup: 1.06x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 42.26 µs
new kernel: 40.92 µs
speedup: 1.03x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 41.38 µs
new kernel: 39.30 µs
speedup: 1.05x
Shape: 1024x1024
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.07 µs
new kernel: 33.93 µs
speedup: 1.03x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.57 µs
new kernel: 39.55 µs
speedup: 1.10x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.72 µs
new kernel: 38.96 µs
speedup: 1.12x
Shape: 4096x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 71.64 µs
new kernel: 58.66 µs
speedup: 1.22x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 81.67 µs
new kernel: 57.98 µs
speedup: 1.41x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 82.19 µs
new kernel: 57.56 µs
speedup: 1.43x
Shape: 8192x8192
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 176.85 µs
new kernel: 135.78 µs
speedup: 1.30x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 217.99 µs
new kernel: 121.84 µs
speedup: 1.79x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 215.47 µs
new kernel: 117.41 µs
speedup: 1.84x
Shape: 8192x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 248.18 µs
new kernel: 186.64 µs
speedup: 1.33x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 306.25 µs
new kernel: 163.28 µs
speedup: 1.88x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 303.06 µs
new kernel: 157.59 µs
speedup: 1.92x
Shape: 12288x12288
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 354.23 µs
new kernel: 262.99 µs
speedup: 1.35x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 439.44 µs
new kernel: 224.71 µs
speedup: 1.96x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 434.23 µs
new kernel: 217.62 µs
speedup: 2.00x
Shape: 32x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 35.90 µs
new kernel: 34.88 µs
speedup: 1.03x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 43.77 µs
new kernel: 41.49 µs
speedup: 1.05x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 43.22 µs
new kernel: 41.79 µs
speedup: 1.03x
Shape: 1024x4096
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 37.37 µs
new kernel: 37.84 µs
speedup: 0.99x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 49.69 µs
new kernel: 43.85 µs
speedup: 1.13x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 48.93 µs
new kernel: 44.31 µs
speedup: 1.10x
Shape: 32x5000
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 41.83 µs
new kernel: 35.44 µs
speedup: 1.18x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 53.23 µs
new kernel: 40.64 µs
speedup: 1.31x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 54.39 µs
new kernel: 40.77 µs
speedup: 1.33x
Shape: 128x8200
dtype: torch.float32
max abs diff: 0.000e+00
old kernel: 49.35 µs
new kernel: 35.33 µs
speedup: 1.40x
dtype: torch.bfloat16
max abs diff: 0.000e+00
old kernel: 60.89 µs
new kernel: 41.46 µs
speedup: 1.47x
dtype: torch.float16
max abs diff: 0.000e+00
old kernel: 61.75 µs
new kernel: 41.75 µs
speedup: 1.48x
```
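For context, NVFP4 quantizes tensors in 16-element scaling blocks of FP4 (E2M1) values. The sketch below is a simplified NumPy illustration of that block-quantization math, not the Triton kernel itself; it uses FP32 block scales for clarity (real NVFP4 stores FP8 E4M3 scales plus a global tensor scale), and `quantize_nvfp4` is a hypothetical helper name.

```python
# Simplified NumPy sketch of NVFP4-style block quantization.
# Assumption: FP32 block scales instead of NVFP4's FP8 E4M3 scales.
import numpy as np

# Representable FP4 (E2M1) magnitudes.
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
BLOCK = 16  # NVFP4 scaling-block size

def quantize_nvfp4(x: np.ndarray):
    """Quantize each 16-element block with its own scale.
    Requires x.size to be a multiple of BLOCK."""
    blocks = x.reshape(-1, BLOCK)
    # Scale so each block's amax maps to the largest FP4 value (6.0).
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 6.0
    scales[scales == 0] = 1.0                      # avoid divide-by-zero
    scaled = blocks / scales
    # Round each element to the nearest representable FP4 magnitude.
    idx = np.abs(np.abs(scaled)[..., None] - FP4_VALUES).argmin(axis=-1)
    q = np.sign(scaled) * FP4_VALUES[idx]
    return q.reshape(x.shape), scales

def dequantize(q: np.ndarray, scales: np.ndarray):
    return (q.reshape(-1, BLOCK) * scales).reshape(q.shape)
```

In the scaled domain every value lies in [-6, 6] and the widest gap between adjacent FP4 magnitudes is 2 (from 4 to 6), so the per-element error is bounded by the block scale; values that are already representable round-trip exactly, which is consistent with the `max abs diff: 0.000e+00` comparisons above.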
## Testing
1. Compared outputs with the old kernel; max abs diff = 0.
2. Benchmarked speed (results above).
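The comparison methodology can be sketched as follows; this is a hypothetical CPU harness (the `compare` helper and the stand-in kernels are illustrative, not the actual benchmark script), showing how one might check bit-exactness and measure a speedup between two implementations.

```python
# Hypothetical sketch of the correctness + speed comparison; the real
# benchmark ran Triton kernels on GPU, this only shows the methodology.
import time
import numpy as np

def compare(old_fn, new_fn, x, iters=100):
    # Correctness: bit-exact kernels should give max abs diff == 0.
    max_abs_diff = float(np.abs(old_fn(x) - new_fn(x)).max())

    def bench(fn):
        fn(x)                                 # warm-up
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x)
        return (time.perf_counter() - t0) / iters

    t_old, t_new = bench(old_fn), bench(new_fn)
    return max_abs_diff, t_old / t_new        # diff and speedup ratio

# Stand-in "kernels" for illustration only: x * 2 and x + x are
# bit-identical in float32, so the diff is exactly zero.
diff, speedup = compare(lambda x: x * 2.0, lambda x: x + x,
                        np.ones((512, 512), dtype=np.float32))
```

On GPU, wall-clock timing like this would be misleading because kernel launches are asynchronous; a real harness would use CUDA events or `triton.testing.do_bench`, which handle synchronization and warm-up.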
## Before your PR is "*Ready for review*"
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No
## Additional Information
Bug [5612406]
---------
Signed-off-by: mxin <mxin@nvidia.com>
1 file changed: +107 −63 lines