From 0cf5fb60b129eb7c38f83b8e8d536c643a9213f7 Mon Sep 17 00:00:00 2001
From: mxin
Date: Mon, 10 Nov 2025 19:08:19 -0800
Subject: [PATCH 1/3] new fp4 kernel

Signed-off-by: mxin
---
 .../torch/quantization/triton/fp4_kernel.py   | 151 +++++++++++-------
 1 file changed, 89 insertions(+), 62 deletions(-)

diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py
index 33b38b7a9..40a6ecc1c 100644
--- a/modelopt/torch/quantization/triton/fp4_kernel.py
+++ b/modelopt/torch/quantization/triton/fp4_kernel.py
@@ -34,60 +34,60 @@ def fp4_fake_quant_kernel(
     M,
     N,
     global_scale_ptr,
+    stride_xm,
+    stride_xn,
+    stride_ym,
+    stride_yn,
     BLOCK_SIZE: tl.constexpr,
-    TILE_SIZE: tl.constexpr,
+    TILE_M: tl.constexpr,
+    TILE_N: tl.constexpr,
     NUM_FP4_BLOCKS: tl.constexpr,
 ):
-    """Applies FP4 fake quantization on input data using per-block scaling factors.
-
-    Args:
-        x_ptr (tl.pointer): Pointer to the input tensor (BF16/FP32)
-        y_ptr (tl.pointer): Pointer to the output buffer
-        M (int): Number of rows in the matrix
-        N (int): Number of columns in the matrix
-        global_scale_ptr (tl.pointer): Pointer to the global scaling factor tensor
-        BLOCK_SIZE (tl.constexpr): Size of each FP4 quantization block
-        TILE_SIZE (tl.constexpr): Size of the processing block
-        NUM_FP4_BLOCKS (tl.constexpr): Number of FP4 blocks within TILE_SIZE
-    """
+    """Applies FP4 fake quantization using block pointers for memory addressing."""
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
 
-    # Load global scale from tensor
-    global_scale = tl.load(global_scale_ptr).to(tl.float32)
+    row_start = pid_m * TILE_M
+    col_start = pid_n * TILE_N
 
-    # Calculate offsets
-    offs_m = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
-    offs_n = pid_n * TILE_SIZE + tl.arange(0, TILE_SIZE)
-    offs = offs_m[:, None] * N + offs_n[None, :]
-    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x_block_ptr = tl.make_block_ptr(
+        base=x_ptr,
+        shape=(M, N),
+        strides=(stride_xm, stride_xn),
+        offsets=(row_start, col_start),
+        block_shape=(TILE_M, TILE_N),
+        order=(1, 0),
+    )
+    y_block_ptr = tl.make_block_ptr(
+        base=y_ptr,
+        shape=(M, N),
+        strides=(stride_ym, stride_yn),
+        offsets=(row_start, col_start),
+        block_shape=(TILE_M, TILE_N),
+        order=(1, 0),
+    )
+
+    global_scale = tl.load(global_scale_ptr).to(tl.float32)
+    global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12)
 
-    # Load input data
-    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
 
-    # Reshape for block processing
-    x_reshaped = tl.reshape(x, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE))
-    x_abs = tl.abs(x_reshaped)
+    tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE))
+    x_abs = tl.abs(tile_reshaped)
 
-    # Calculate max values for each FP4 block
     block_max = tl.max(x_abs, axis=2, keep_dims=True)
-    # global_scale = global_amax / (448 * 6)
-    block_max_quant = (
-        tl.minimum((block_max / (6.0 * global_scale)), 448.0).to(tl.float8e4nv).to(tl.float32)
-        * global_scale
-    )
-    # Broadcast max values
+    block_max_scaled = block_max / (6.0 * global_scale_safe)
+    block_max_scaled = tl.minimum(block_max_scaled, 448.0)
+    block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
+    block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)
+
     block_max_quant_broadcast = tl.broadcast_to(
-        block_max_quant, (TILE_SIZE, NUM_FP4_BLOCKS, BLOCK_SIZE)
-    )
-    # Set scale to 1 if block amax is 0
-    block_max_quant_broadcast = tl.where(
-        block_max_quant_broadcast < 1e-5, 1.0, block_max_quant_broadcast
+        block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)
     )
+
     abs_scaled = x_abs / block_max_quant_broadcast
 
-    # Quantize to FP4 values: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, following round to even
     q_val = tl.where(
         abs_scaled <= 0.25,
         0.0,
@@ -103,63 +103,90 @@ def fp4_fake_quant_kernel(
                     tl.where(
                         abs_scaled <= 2.5,
                         2.0,
-                        tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
+                        tl.where(
+                            abs_scaled < 3.5,
+                            3.0,
+                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
+                        ),
                     ),
                 ),
             ),
         ),
     )
 
-    # Apply signs and rescale
     x_rescaled = q_val * block_max_quant_broadcast
-    x_rescaled = tl.where(x_reshaped >= 0, x_rescaled, -x_rescaled)
+    x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)
+
+    tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N))
 
-    # Reshape back and store
-    x_rescaled = tl.reshape(x_rescaled, (TILE_SIZE, TILE_SIZE))
-    tl.store(y_ptr + offs, x_rescaled, mask=mask)
+    tl.store(y_block_ptr, tile_quant, boundary_check=(0, 1))
 
 
 def fp4_fake_quant_block(
     x: torch.Tensor,
     global_amax: torch.Tensor,
     block_size: int = 16,
-    tile_size: int = 128,
+    tile_rows: int = 16,
+    tile_cols: int = 64,
+    num_warps: int | None = None,
+    num_stages: int | None = None,
 ) -> torch.Tensor:
-    """Applies FP4 fake quantization on the input tensor.
+    """FP4 fake quantization implementation using block-pointer tiling.
 
     Args:
-        x (torch.Tensor): Input tensor of shape (M, N)
-        global_amax (torch.Tensor): Global max value of the input tensor
-            This needs to be a tensor to be cuda-graph compatible
-        block_size (int): Size of FP4 quantization blocks
-        tile_size (int): Size of processing blocks
+        x (torch.Tensor): Input tensor of shape ``(M, N)``; higher-rank inputs are
+            flattened over their leading dimensions.
+        global_amax (torch.Tensor): Global maximum value tensor for scaling.
+        block_size (int): Number of elements per FP4 block.
+        tile_rows (int, optional): Row tile size. Defaults to 16.
+        tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to
+            the nearest multiple of ``block_size`` internally.
+        num_warps (int | None, optional): Override for the number of Triton warps.
+            Triton's launch default is used when ``None``.
+        num_stages (int | None, optional): Override for pipeline stages. Triton's
+            launch default is used when ``None``.
 
     Returns:
-        torch.Tensor: Quantized tensor of the same shape as input
+        torch.Tensor: Fake-quantized tensor matching the input shape and dtype.
     """
     x_shape = x.shape
     x_dtype = x.dtype
     x = x.reshape(-1, x_shape[-1]).contiguous()
-    M, N = x.size()
-    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+
+    M, N = x.shape
+    y = torch.empty_like(x, dtype=torch.float32)
+
+    stride_xm, stride_xn = x.stride()
+    stride_ym, stride_yn = y.stride()
+
+    tile_cols = max(tile_cols, block_size)
+    tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size
+    num_fp4_blocks = tile_cols_aligned // block_size
 
-    grid = lambda meta: (
-        triton.cdiv(M, meta["TILE_SIZE"]),
-        triton.cdiv(N, meta["TILE_SIZE"]),
-    )
     global_scale = global_amax.float() / (6.0 * 448.0)
-    num_fp4_blocks = tile_size // block_size
+
+    grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned))
+
+    launch_kwargs = {
+        "BLOCK_SIZE": block_size,
+        "TILE_M": tile_rows,
+        "TILE_N": tile_cols_aligned,
+        "NUM_FP4_BLOCKS": num_fp4_blocks,
+    }
+    if num_warps is not None:
+        launch_kwargs["num_warps"] = num_warps
+    if num_stages is not None:
+        launch_kwargs["num_stages"] = num_stages
 
     fp4_fake_quant_kernel[grid](
         x,
         y,
         M,
         N,
         global_scale,
-        TILE_SIZE=tile_size,
-        BLOCK_SIZE=block_size,
-        NUM_FP4_BLOCKS=num_fp4_blocks,
+        stride_xm,
+        stride_xn,
+        stride_ym,
+        stride_yn,
+        **launch_kwargs,
    )
+
+    y = y.reshape(x_shape).contiguous().to(dtype=x_dtype)
     return y
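A note on PATCH 1 (not part of the patch series): the kernel's math can be cross-checked against a slow eager-mode PyTorch reference. The sketch below is an assumption-laden helper, not project code: `ref_fp4_fake_quant` is a hypothetical name, it assumes PyTorch >= 2.1 for `torch.float8_e4m3fn`, it assumes `x.numel()` is divisible by `block_size`, and the 0.5/1.0/1.5 boundaries (which fall between the two diff hunks above) are inferred from the same round-half-to-even pattern visible at 2.5/3.5/5.0.

    import torch

    def ref_fp4_fake_quant(x: torch.Tensor, global_amax: torch.Tensor,
                           block_size: int = 16) -> torch.Tensor:
        # Slow reference mirroring the Triton kernel's FP4 fake-quant math.
        shape, dtype = x.shape, x.dtype
        xb = x.reshape(-1, block_size).float()
        global_scale = global_amax.float() / (6.0 * 448.0)
        # Per-block amax -> FP8 E4M3 scale, round-tripped like the kernel does.
        block_max = xb.abs().amax(dim=-1, keepdim=True)
        scale = (block_max / (6.0 * global_scale)).clamp(max=448.0)
        scale = scale.to(torch.float8_e4m3fn).float() * global_scale
        scale = torch.where(scale >= 1e-5, scale, torch.ones_like(scale))  # zero-amax blocks
        a = xb.abs() / scale
        # Snap |x|/scale onto {0, 0.5, 1, 1.5, 2, 3, 4, 6}; a strict '<' is used at
        # boundaries whose upper neighbour has an even mantissa (round half to even).
        q = torch.where(
            a <= 0.25, 0.0,
            torch.where(
                a < 0.75, 0.5,
                torch.where(
                    a <= 1.25, 1.0,
                    torch.where(
                        a < 1.75, 1.5,
                        torch.where(
                            a <= 2.5, 2.0,
                            torch.where(
                                a < 3.5, 3.0,
                                torch.where(a <= 5.0, 4.0, torch.full_like(a, 6.0)),
                            ),
                        ),
                    ),
                ),
            ),
        )
        y = torch.where(xb >= 0, q * scale, -(q * scale))
        return y.reshape(shape).to(dtype)

Comparing `fp4_fake_quant_block(x, x.abs().amax())` against `ref_fp4_fake_quant(x, x.abs().amax())` with `torch.testing.assert_close` is a reasonable smoke test for this patch.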
""" x_shape = x.shape x_dtype = x.dtype x = x.reshape(-1, x_shape[-1]).contiguous() - M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) + M, N = x.shape + y = torch.empty_like(x, dtype=torch.float32) + + stride_xm, stride_xn = x.stride() + stride_ym, stride_yn = y.stride() + + tile_cols = max(tile_cols, block_size) + tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size + num_fp4_blocks = tile_cols_aligned // block_size - grid = lambda meta: ( - triton.cdiv(M, meta["TILE_SIZE"]), - triton.cdiv(N, meta["TILE_SIZE"]), - ) global_scale = global_amax.float() / (6.0 * 448.0) - num_fp4_blocks = tile_size // block_size + + grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) + + launch_kwargs = { + "BLOCK_SIZE": block_size, + "TILE_M": tile_rows, + "TILE_N": tile_cols_aligned, + "NUM_FP4_BLOCKS": num_fp4_blocks, + } + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages fp4_fake_quant_kernel[grid]( x, y, M, N, global_scale, - TILE_SIZE=tile_size, - BLOCK_SIZE=block_size, - NUM_FP4_BLOCKS=num_fp4_blocks, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + **launch_kwargs, ) + y = y.reshape(x_shape).contiguous().to(dtype=x_dtype) return y From 36bdf82664e2e46aae63fa4bf5f70427c42b7bad Mon Sep 17 00:00:00 2001 From: mxin Date: Wed, 12 Nov 2025 21:52:57 -0800 Subject: [PATCH 2/3] save memory for bf16/fp16 by cast dtype in kernel Signed-off-by: mxin --- .../torch/quantization/triton/fp4_kernel.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 40a6ecc1c..93c61b658 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -27,6 +27,21 @@ __all__ = ["fp4_fake_quant_block"] +_TORCH_TO_TL_DTYPE = { + torch.float32: tl.float32, + torch.float: tl.float32, + torch.float16: tl.float16, + torch.half: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +def _torch_dtype_to_tl(dtype: torch.dtype): + if dtype not in _TORCH_TO_TL_DTYPE: + raise ValueError(f"Unsupported dtype for fp4 fake quantization: {dtype}") + return _TORCH_TO_TL_DTYPE[dtype] + + @triton.jit def fp4_fake_quant_kernel( x_ptr, @@ -42,6 +57,7 @@ def fp4_fake_quant_kernel( TILE_M: tl.constexpr, TILE_N: tl.constexpr, NUM_FP4_BLOCKS: tl.constexpr, + OUT_DTYPE: tl.constexpr, ): """Applies FP4 fake quantization using block pointers for memory addressing.""" pid_m = tl.program_id(axis=0) @@ -119,7 +135,7 @@ def fp4_fake_quant_kernel( tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N)) - tl.store(y_block_ptr, tile_quant, boundary_check=(0, 1)) + tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1)) def fp4_fake_quant_block( @@ -151,7 +167,7 @@ def fp4_fake_quant_block( x = x.reshape(-1, x_shape[-1]).contiguous() M, N = x.shape - y = torch.empty_like(x, dtype=torch.float32) + y = torch.empty_like(x) stride_xm, stride_xn = x.stride() stride_ym, stride_yn = y.stride() @@ -169,6 +185,7 @@ def fp4_fake_quant_block( "TILE_M": tile_rows, "TILE_N": tile_cols_aligned, "NUM_FP4_BLOCKS": num_fp4_blocks, + "OUT_DTYPE": _torch_dtype_to_tl(x_dtype), } if num_warps is not None: launch_kwargs["num_warps"] = num_warps @@ -187,7 +204,7 @@ def fp4_fake_quant_block( **launch_kwargs, ) - y = y.reshape(x_shape).contiguous().to(dtype=x_dtype) + y = y.reshape(x_shape) return y From 
From 106da63d011b7795dccf66400515b0098a45be28 Mon Sep 17 00:00:00 2001
From: mxin
Date: Thu, 13 Nov 2025 18:09:39 -0800
Subject: [PATCH 3/3] use view

Signed-off-by: mxin
---
 modelopt/torch/quantization/triton/fp4_kernel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py
index 93c61b658..f2f9bd077 100644
--- a/modelopt/torch/quantization/triton/fp4_kernel.py
+++ b/modelopt/torch/quantization/triton/fp4_kernel.py
@@ -204,7 +204,7 @@ def fp4_fake_quant_block(
         **launch_kwargs,
     )
 
-    y = y.reshape(x_shape)
+    y = y.view(*x_shape)
     return y
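A note on PATCH 3 (not part of the patch series): `y` comes from `torch.empty_like` on a contiguous 2-D tensor, so `view` is guaranteed to alias the existing storage, whereas `reshape` only avoids a copy when it happens to be possible. A small sketch of the distinction:

    import torch

    y = torch.empty(1024, 1024)  # contiguous, like the kernel's output buffer
    v = y.view(-1)               # zero-copy by contract; raises if it cannot alias
    assert v.data_ptr() == y.data_ptr()

    r = y.t().reshape(-1)        # non-contiguous source: reshape silently copies
    assert r.data_ptr() != y.data_ptr()
    # y.t().view(-1) would raise a RuntimeError instead of copying.

Using `view(*x_shape)` therefore documents, and enforces at runtime, that the final reshape is free.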