Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 107 additions & 63 deletions modelopt/torch/quantization/triton/fp4_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,67 +27,83 @@
__all__ = ["fp4_fake_quant_block"]


_TORCH_TO_TL_DTYPE = {
torch.float32: tl.float32,
torch.float: tl.float32,
torch.float16: tl.float16,
torch.half: tl.float16,
torch.bfloat16: tl.bfloat16,
}


def _torch_dtype_to_tl(dtype: torch.dtype):
if dtype not in _TORCH_TO_TL_DTYPE:
raise ValueError(f"Unsupported dtype for fp4 fake quantization: {dtype}")
return _TORCH_TO_TL_DTYPE[dtype]


@triton.jit
def fp4_fake_quant_kernel(
x_ptr,
y_ptr,
M,
N,
global_scale_ptr,
stride_xm,
stride_xn,
stride_ym,
stride_yn,
BLOCK_SIZE: tl.constexpr,
TILE_SIZE: tl.constexpr,
TILE_M: tl.constexpr,
TILE_N: tl.constexpr,
NUM_FP4_BLOCKS: tl.constexpr,
OUT_DTYPE: tl.constexpr,
):
"""Applies FP4 fake quantization on input data using per-block scaling factors.

Args:
x_ptr (tl.pointer): Pointer to the input tensor (BF16/FP32)
y_ptr (tl.pointer): Pointer to the output buffer
M (int): Number of rows in the matrix
N (int): Number of columns in the matrix
global_scale_ptr (tl.pointer): Pointer to the global scaling factor tensor
BLOCK_SIZE (tl.constexpr): Size of each FP4 quantization block
TILE_SIZE (tl.constexpr): Size of the processing block
NUM_FP4_BLOCKS (tl.constexpr): Number of FP4 blocks within TILE_SIZE
"""
"""Applies FP4 fake quantization using block pointers for memory addressing."""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)

# Load global scale from tensor
global_scale = tl.load(global_scale_ptr).to(tl.float32)
row_start = pid_m * TILE_M
col_start = pid_n * TILE_N

# Calculate offsets
offs_m = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
offs_n = pid_n * TILE_SIZE + tl.arange(0, TILE_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x_block_ptr = tl.make_block_ptr(
base=x_ptr,
shape=(M, N),
strides=(stride_xm, stride_xn),
offsets=(row_start, col_start),
block_shape=(TILE_M, TILE_N),
order=(1, 0),
)
y_block_ptr = tl.make_block_ptr(
base=y_ptr,
shape=(M, N),
strides=(stride_ym, stride_yn),
offsets=(row_start, col_start),
block_shape=(TILE_M, TILE_N),
order=(1, 0),
)

global_scale = tl.load(global_scale_ptr).to(tl.float32)
global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)

# Load input data
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, it looks like the old version has the proper mask in tl.load and tl.store. Why does it cause the nvbug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the illegal memory is hard to debug, because the error message never directs to the correct position. I didn't find the root cause actually, just guess it was the addressing issue. So changed the way to load and it's fixed. That bug is a rare case, It's never seen before.

tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)

# Reshape for block processing
x_reshaped = tl.reshape(x, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE))
x_abs = tl.abs(x_reshaped)
tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
x_abs = tl.abs(tile_reshaped)

# Calculate max values for each FP4 block
block_max = tl.max(x_abs, axis=2, keep_dims=True)
# global_scale = global_amax / (448 * 6)
block_max_quant = (
tl.minimum((block_max / (6.0 * global_scale)), 448.0).to(tl.float8e4nv).to(tl.float32)
* global_scale
)

# Broadcast max values
block_max_scaled = block_max / (6.0 * global_scale_safe)
block_max_scaled = tl.minimum(block_max_scaled, 448.0)
block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)

block_max_quant_broadcast = tl.broadcast_to(
block_max_quant, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE)
)
# Set scale to 1 if block amax is 0
block_max_quant_broadcast = tl.where(
block_max_quant_broadcast < 1e-5, 1.0, block_max_quant_broadcast
block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)
)

abs_scaled = x_abs / block_max_quant_broadcast

# Quantize to FP4 values: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, following round to even
q_val = tl.where(
abs_scaled <= 0.25,
0.0,
Expand All @@ -103,64 +119,92 @@ def fp4_fake_quant_kernel(
tl.where(
abs_scaled <= 2.5,
2.0,
tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
tl.where(
abs_scaled < 3.5,
3.0,
tl.where(abs_scaled <= 5.0, 4.0, 6.0),
),
),
),
),
),
)

# Apply signs and rescale
x_rescaled = q_val * block_max_quant_broadcast
x_rescaled = tl.where(x_reshaped >= 0, x_rescaled, -x_rescaled)
x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)

# Reshape back and store
x_rescaled = tl.reshape(x_rescaled, (TILE_SIZE, TILE_SIZE))
tl.store(y_ptr + offs, x_rescaled, mask=mask)
tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N))

tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1))


def fp4_fake_quant_block(
x: torch.Tensor,
global_amax: torch.Tensor,
block_size: int = 16,
tile_size: int = 128,
tile_rows: int = 16,
tile_cols: int = 64,
num_warps: int | None = None,
num_stages: int | None = None,
) -> torch.Tensor:
"""Applies FP4 fake quantization on the input tensor.
"""FP4 fake quantization implementation using block-pointer tiling.

Args:
x (torch.Tensor): Input tensor of shape (M, N)
global_amax (torch.Tensor): Global max value of the input tensor
This needs to be a tensor to be cuda-graph compatible
block_size (int): Size of FP4 quantization blocks
tile_size (int): Size of processing blocks
x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher.
global_amax (torch.Tensor): Global maximum value tensor for scaling.
block_size (int): Number of elements per FP4 block.
tile_rows (int, optional): Row tile size. Defaults to 64.
tile_cols (int, optional): Column tile size. Defaults to 128. Rounded up to
the nearest multiple of ``block_size`` internally.
num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``.
num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``.

Returns:
torch.Tensor: Quantized tensor of the same shape as input
torch.Tensor: Fake-quantized tensor matching the input shape and dtype.
"""
x_shape = x.shape
x_dtype = x.dtype
x = x.reshape(-1, x_shape[-1]).contiguous()

M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
M, N = x.shape
y = torch.empty_like(x)

stride_xm, stride_xn = x.stride()
stride_ym, stride_yn = y.stride()

tile_cols = max(tile_cols, block_size)
tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
num_fp4_blocks = tile_cols_aligned // block_size

grid = lambda meta: (
triton.cdiv(M, meta["TILE_SIZE"]),
triton.cdiv(N, meta["TILE_SIZE"]),
)
global_scale = global_amax.float() / (6.0 * 448.0)
num_fp4_blocks = tile_size // block_size

grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))

launch_kwargs = {
"BLOCK_SIZE": block_size,
"TILE_M": tile_rows,
"TILE_N": tile_cols_aligned,
"NUM_FP4_BLOCKS": num_fp4_blocks,
"OUT_DTYPE": _torch_dtype_to_tl(x_dtype),
}
if num_warps is not None:
launch_kwargs["num_warps"] = num_warps
if num_stages is not None:
launch_kwargs["num_stages"] = num_stages
fp4_fake_quant_kernel[grid](
x,
y,
M,
N,
global_scale,
TILE_SIZE=tile_size,
BLOCK_SIZE=block_size,
NUM_FP4_BLOCKS=num_fp4_blocks,
stride_xm,
stride_xn,
stride_ym,
stride_yn,
**launch_kwargs,
)
y = y.reshape(x_shape).contiguous().to(dtype=x_dtype)

y = y.view(*x_shape)
return y


Expand Down