Skip to content

Commit fd8db7b

Browse files
mxinOjQizhang
authored andcommitted
Optimize NVFP4 Triton kernel (NVIDIA#533)
## What does this PR do? **Type of change:** Bug fix <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** 1. Use mak_block_ptr for loading blocks, now it's more safe, fix illegal memory access in rare cases. 2. Now the tile rows and columns can be specified separately. 3. Moving data type cast to kernel to save memory for bf16/fp16 inputs. 4. I did a benchmark comparing with the old kernel on H100 and B200, it has significant speed-up for medium and large size inputs (B200: 1.4x - 2x, H100: 1.7x - 2.8x) H100: ```shell Shape: 512x512 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 35.32 µs new kernel: 38.49 µs speedup: 0.92x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 43.48 µs new kernel: 44.78 µs speedup: 0.97x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 43.25 µs new kernel: 43.69 µs speedup: 0.99x Shape: 1024x1024 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 36.03 µs new kernel: 38.17 µs speedup: 0.94x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 44.24 µs new kernel: 43.78 µs speedup: 1.01x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 43.77 µs new kernel: 43.61 µs speedup: 1.00x Shape: 4096x4096 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 87.02 µs new kernel: 80.88 µs speedup: 1.08x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 116.12 µs new kernel: 65.80 µs speedup: 1.76x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 114.39 µs new kernel: 65.30 µs speedup: 1.75x Shape: 8192x8192 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 237.29 µs new kernel: 219.42 µs speedup: 1.08x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 349.76 µs new kernel: 138.66 µs speedup: 2.52x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 341.89 µs new kernel: 136.91 µs speedup: 2.50x Shape: 8192x12288 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 338.65 µs new kernel: 312.70 µs speedup: 1.08x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 505.63 µs new kernel: 188.24 µs speedup: 2.69x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 492.97 µs new kernel: 186.88 µs speedup: 2.64x Shape: 12288x12288 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 490.25 µs new kernel: 451.16 µs speedup: 1.09x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 736.04 µs new kernel: 261.94 µs speedup: 2.81x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 717.64 µs new kernel: 257.82 µs speedup: 2.78x Shape: 32x4096 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 35.61 µs new kernel: 38.23 µs speedup: 0.93x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 43.00 µs new kernel: 43.85 µs speedup: 0.98x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 42.83 µs new kernel: 44.13 µs speedup: 0.97x Shape: 1024x4096 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 38.12 µs new kernel: 41.28 µs speedup: 0.92x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 52.80 µs new kernel: 45.96 µs speedup: 1.15x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 51.56 µs new kernel: 45.30 µs speedup: 1.14x Shape: 32x5000 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 41.70 µs new kernel: 38.03 µs speedup: 1.10x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 52.95 µs new kernel: 44.14 µs speedup: 1.20x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 52.57 µs new kernel: 44.38 µs speedup: 1.18x Shape: 32x5000 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 41.70 µs new kernel: 38.03 µs speedup: 1.10x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 52.95 µs new kernel: 44.14 µs speedup: 1.20x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 52.57 µs new kernel: 44.38 µs speedup: 1.18x Shape: 128x8200 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 48.03 µs new kernel: 38.38 µs speedup: 1.25x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 60.54 µs new kernel: 44.51 µs speedup: 1.36x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 60.08 µs new kernel: 43.59 µs speedup: 1.38x ``` B200: ```shell Shape: 512x512 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 34.63 µs new kernel: 32.80 µs speedup: 1.06x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 42.26 µs new kernel: 40.92 µs speedup: 1.03x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 41.38 µs new kernel: 39.30 µs speedup: 1.05x Shape: 1024x1024 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 35.07 µs new kernel: 33.93 µs speedup: 1.03x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 43.57 µs new kernel: 39.55 µs speedup: 1.10x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 43.72 µs new kernel: 38.96 µs speedup: 1.12x Shape: 4096x4096 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 71.64 µs new kernel: 58.66 µs speedup: 1.22x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 81.67 µs new kernel: 57.98 µs speedup: 1.41x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 82.19 µs new kernel: 57.56 µs speedup: 1.43x Shape: 8192x8192 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 176.85 µs new kernel: 135.78 µs speedup: 1.30x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 217.99 µs new kernel: 121.84 µs speedup: 1.79x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 215.47 µs new kernel: 117.41 µs speedup: 1.84x Shape: 8192x12288 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 248.18 µs new kernel: 186.64 µs speedup: 1.33x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 306.25 µs new kernel: 163.28 µs speedup: 1.88x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 303.06 µs new kernel: 157.59 µs speedup: 1.92x Shape: 12288x12288 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 354.23 µs new kernel: 262.99 µs speedup: 1.35x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 439.44 µs new kernel: 224.71 µs speedup: 1.96x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 434.23 µs new kernel: 217.62 µs speedup: 2.00x Shape: 32x4096 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 35.90 µs new kernel: 34.88 µs speedup: 1.03x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 43.77 µs new kernel: 41.49 µs speedup: 1.05x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 43.22 µs new kernel: 41.79 µs speedup: 1.03x Shape: 1024x4096 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 37.37 µs new kernel: 37.84 µs speedup: 0.99x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 49.69 µs new kernel: 43.85 µs speedup: 1.13x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 48.93 µs new kernel: 44.31 µs speedup: 1.10x Shape: 32x5000 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 41.83 µs new kernel: 35.44 µs speedup: 1.18x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 53.23 µs new kernel: 40.64 µs speedup: 1.31x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 54.39 µs new kernel: 40.77 µs speedup: 1.33x Shape: 128x8200 dtype: torch.float32 max abs diff: 0.000e+00 old kernel: 49.35 µs new kernel: 35.33 µs speedup: 1.40x dtype: torch.bfloat16 max abs diff: 0.000e+00 old kernel: 60.89 µs new kernel: 41.46 µs speedup: 1.47x dtype: torch.float16 max abs diff: 0.000e+00 old kernel: 61.75 µs new kernel: 41.75 µs speedup: 1.48x ``` ## Testing <!-- Mention how have you tested your change if applicable. --> 1. Compared with old kernel, diff=0 2. Benchmark speed ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: No - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information Bug [5612406] --------- Signed-off-by: mxin <mxin@nvidia.com>
1 parent b2f7c4f commit fd8db7b

File tree

1 file changed

+107
-63
lines changed

1 file changed

+107
-63
lines changed

modelopt/torch/quantization/triton/fp4_kernel.py

Lines changed: 107 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -27,67 +27,83 @@
2727
__all__ = ["fp4_fake_quant_block"]
2828

2929

30+
_TORCH_TO_TL_DTYPE = {
31+
torch.float32: tl.float32,
32+
torch.float: tl.float32,
33+
torch.float16: tl.float16,
34+
torch.half: tl.float16,
35+
torch.bfloat16: tl.bfloat16,
36+
}
37+
38+
39+
def _torch_dtype_to_tl(dtype: torch.dtype):
40+
if dtype not in _TORCH_TO_TL_DTYPE:
41+
raise ValueError(f"Unsupported dtype for fp4 fake quantization: {dtype}")
42+
return _TORCH_TO_TL_DTYPE[dtype]
43+
44+
3045
@triton.jit
3146
def fp4_fake_quant_kernel(
3247
x_ptr,
3348
y_ptr,
3449
M,
3550
N,
3651
global_scale_ptr,
52+
stride_xm,
53+
stride_xn,
54+
stride_ym,
55+
stride_yn,
3756
BLOCK_SIZE: tl.constexpr,
38-
TILE_SIZE: tl.constexpr,
57+
TILE_M: tl.constexpr,
58+
TILE_N: tl.constexpr,
3959
NUM_FP4_BLOCKS: tl.constexpr,
60+
OUT_DTYPE: tl.constexpr,
4061
):
41-
"""Applies FP4 fake quantization on input data using per-block scaling factors.
42-
43-
Args:
44-
x_ptr (tl.pointer): Pointer to the input tensor (BF16/FP32)
45-
y_ptr (tl.pointer): Pointer to the output buffer
46-
M (int): Number of rows in the matrix
47-
N (int): Number of columns in the matrix
48-
global_scale_ptr (tl.pointer): Pointer to the global scaling factor tensor
49-
BLOCK_SIZE (tl.constexpr): Size of each FP4 quantization block
50-
TILE_SIZE (tl.constexpr): Size of the processing block
51-
NUM_FP4_BLOCKS (tl.constexpr): Number of FP4 blocks within TILE_SIZE
52-
"""
62+
"""Applies FP4 fake quantization using block pointers for memory addressing."""
5363
pid_m = tl.program_id(axis=0)
5464
pid_n = tl.program_id(axis=1)
5565

56-
# Load global scale from tensor
57-
global_scale = tl.load(global_scale_ptr).to(tl.float32)
66+
row_start = pid_m * TILE_M
67+
col_start = pid_n * TILE_N
5868

59-
# Calculate offsets
60-
offs_m = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
61-
offs_n = pid_n * TILE_SIZE + tl.arange(0, TILE_SIZE)
62-
offs = offs_m[:, None] * N + offs_n[None, :]
63-
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
69+
x_block_ptr = tl.make_block_ptr(
70+
base=x_ptr,
71+
shape=(M, N),
72+
strides=(stride_xm, stride_xn),
73+
offsets=(row_start, col_start),
74+
block_shape=(TILE_M, TILE_N),
75+
order=(1, 0),
76+
)
77+
y_block_ptr = tl.make_block_ptr(
78+
base=y_ptr,
79+
shape=(M, N),
80+
strides=(stride_ym, stride_yn),
81+
offsets=(row_start, col_start),
82+
block_shape=(TILE_M, TILE_N),
83+
order=(1, 0),
84+
)
85+
86+
global_scale = tl.load(global_scale_ptr).to(tl.float32)
87+
global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)
6488

65-
# Load input data
66-
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
89+
tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
6790

68-
# Reshape for block processing
69-
x_reshaped = tl.reshape(x, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE))
70-
x_abs = tl.abs(x_reshaped)
91+
tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
92+
x_abs = tl.abs(tile_reshaped)
7193

72-
# Calculate max values for each FP4 block
7394
block_max = tl.max(x_abs, axis=2, keep_dims=True)
74-
# global_scale = global_amax / (448 * 6)
75-
block_max_quant = (
76-
tl.minimum((block_max / (6.0 * global_scale)), 448.0).to(tl.float8e4nv).to(tl.float32)
77-
* global_scale
78-
)
7995

80-
# Broadcast max values
96+
block_max_scaled = block_max / (6.0 * global_scale_safe)
97+
block_max_scaled = tl.minimum(block_max_scaled, 448.0)
98+
block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
99+
block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)
100+
81101
block_max_quant_broadcast = tl.broadcast_to(
82-
block_max_quant, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE)
83-
)
84-
# Set scale to 1 if block amax is 0
85-
block_max_quant_broadcast = tl.where(
86-
block_max_quant_broadcast < 1e-5, 1.0, block_max_quant_broadcast
102+
block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)
87103
)
104+
88105
abs_scaled = x_abs / block_max_quant_broadcast
89106

90-
# Quantize to FP4 values: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, following round to even
91107
q_val = tl.where(
92108
abs_scaled <= 0.25,
93109
0.0,
@@ -103,64 +119,92 @@ def fp4_fake_quant_kernel(
103119
tl.where(
104120
abs_scaled <= 2.5,
105121
2.0,
106-
tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
122+
tl.where(
123+
abs_scaled < 3.5,
124+
3.0,
125+
tl.where(abs_scaled <= 5.0, 4.0, 6.0),
126+
),
107127
),
108128
),
109129
),
110130
),
111131
)
112132

113-
# Apply signs and rescale
114133
x_rescaled = q_val * block_max_quant_broadcast
115-
x_rescaled = tl.where(x_reshaped >= 0, x_rescaled, -x_rescaled)
134+
x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)
116135

117-
# Reshape back and store
118-
x_rescaled = tl.reshape(x_rescaled, (TILE_SIZE, TILE_SIZE))
119-
tl.store(y_ptr + offs, x_rescaled, mask=mask)
136+
tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N))
137+
138+
tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1))
120139

121140

122141
def fp4_fake_quant_block(
123142
x: torch.Tensor,
124143
global_amax: torch.Tensor,
125144
block_size: int = 16,
126-
tile_size: int = 128,
145+
tile_rows: int = 16,
146+
tile_cols: int = 64,
147+
num_warps: int | None = None,
148+
num_stages: int | None = None,
127149
) -> torch.Tensor:
128-
"""Applies FP4 fake quantization on the input tensor.
150+
"""FP4 fake quantization implementation using block-pointer tiling.
129151
130152
Args:
131-
x (torch.Tensor): Input tensor of shape (M, N)
132-
global_amax (torch.Tensor): Global max value of the input tensor
133-
This needs to be a tensor to be cuda-graph compatible
134-
block_size (int): Size of FP4 quantization blocks
135-
tile_size (int): Size of processing blocks
153+
x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher.
154+
global_amax (torch.Tensor): Global maximum value tensor for scaling.
155+
block_size (int): Number of elements per FP4 block.
156+
tile_rows (int, optional): Row tile size. Defaults to 64.
157+
tile_cols (int, optional): Column tile size. Defaults to 128. Rounded up to
158+
the nearest multiple of ``block_size`` internally.
159+
num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``.
160+
num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``.
136161
137162
Returns:
138-
torch.Tensor: Quantized tensor of the same shape as input
163+
torch.Tensor: Fake-quantized tensor matching the input shape and dtype.
139164
"""
140165
x_shape = x.shape
141166
x_dtype = x.dtype
142167
x = x.reshape(-1, x_shape[-1]).contiguous()
143168

144-
M, N = x.size()
145-
y = torch.empty_like(x, dtype=torch.get_default_dtype())
169+
M, N = x.shape
170+
y = torch.empty_like(x)
171+
172+
stride_xm, stride_xn = x.stride()
173+
stride_ym, stride_yn = y.stride()
174+
175+
tile_cols = max(tile_cols, block_size)
176+
tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
177+
num_fp4_blocks = tile_cols_aligned // block_size
146178

147-
grid = lambda meta: (
148-
triton.cdiv(M, meta["TILE_SIZE"]),
149-
triton.cdiv(N, meta["TILE_SIZE"]),
150-
)
151179
global_scale = global_amax.float() / (6.0 * 448.0)
152-
num_fp4_blocks = tile_size // block_size
180+
181+
grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))
182+
183+
launch_kwargs = {
184+
"BLOCK_SIZE": block_size,
185+
"TILE_M": tile_rows,
186+
"TILE_N": tile_cols_aligned,
187+
"NUM_FP4_BLOCKS": num_fp4_blocks,
188+
"OUT_DTYPE": _torch_dtype_to_tl(x_dtype),
189+
}
190+
if num_warps is not None:
191+
launch_kwargs["num_warps"] = num_warps
192+
if num_stages is not None:
193+
launch_kwargs["num_stages"] = num_stages
153194
fp4_fake_quant_kernel[grid](
154195
x,
155196
y,
156197
M,
157198
N,
158199
global_scale,
159-
TILE_SIZE=tile_size,
160-
BLOCK_SIZE=block_size,
161-
NUM_FP4_BLOCKS=num_fp4_blocks,
200+
stride_xm,
201+
stride_xn,
202+
stride_ym,
203+
stride_yn,
204+
**launch_kwargs,
162205
)
163-
y = y.reshape(x_shape).contiguous().to(dtype=x_dtype)
206+
207+
y = y.view(*x_shape)
164208
return y
165209

166210

0 commit comments

Comments
 (0)