From 577a570e0bfc9be98c56840db269b2d6b296e36a Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 13:10:38 -0800 Subject: [PATCH 01/19] benchmark for triton_fp8_blockwise_act_quant_transposed_lhs against naive torch implementation --- ...ton_fp8_blockwise_act_quant_transposed_lhs | 315 ++++++++++++++++++ 1 file changed, 315 insertions(+) create mode 100644 benchmarks/prototype/blockwise_fp8_training/bench_ triton_fp8_blockwise_act_quant_transposed_lhs diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_ triton_fp8_blockwise_act_quant_transposed_lhs b/benchmarks/prototype/blockwise_fp8_training/bench_ triton_fp8_blockwise_act_quant_transposed_lhs new file mode 100644 index 0000000000..4415971883 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_ triton_fp8_blockwise_act_quant_transposed_lhs @@ -0,0 +1,315 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_transposed_lhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, K) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + naive_us: float + triton_us: float + # mem bw + naive_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical transformer activation shapes. + Format: (batch_size * seq_len, hidden_dim) + + Note: For transposed_lhs, M must be divisible by block_size + """ + # Llama-style shapes: various batch*seq_len sizes with typical hidden dims + # Ensuring M is divisible by block_size (128) + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + # Verify M is divisible by block_size + if shape[0] % block_size == 0: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, K = config.input_shape + block_size = config.block_size + + def naive_fp8_blockwise_quant_transposed( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Naive PyTorch reference implementation for blockwise FP8 quantization with transpose. + + This version: + 1. Computes column-wise scales (along dimension 0) + 2. Outputs transposed quantized tensor (K, M) in row-major format + 3. 
Outputs scales in shape (K, M//block_size) + + Args: + x: Input tensor of shape (M, K) + block_size: Number of elements per block along M dimension + + Returns: + y: Transposed quantized tensor in FP8, shape (K, M) in row-major + s: Reciprocal scales in column-major format (K, M//block_size) + """ + assert x.is_contiguous(), "Input must be contiguous" + assert x.size(0) % block_size == 0, "M must be divisible by block_size" + + M, K = x.size() + num_blocks = M // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Reshape to (num_blocks, block_size, K) for block-wise operations along M + x_reshaped = x.view(num_blocks, block_size, K) + + # Compute max absolute value per block along dimension 1 (block_size) + # Result shape: (num_blocks, K) + amax = torch.clamp( + x_reshaped.abs().amax(dim=1).to(torch.float64), + min=eps, + max=float('inf') + ) + + # Compute scales (num_blocks, K) -> (num_blocks, 1, K) for broadcasting + scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) + + # Quantize + y_reshaped = x_reshaped * scale + y_reshaped = torch.clamp( + y_reshaped, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Reshape back to (M, K) then transpose to (K, M) + y = y_reshaped.view(M, K).t().contiguous().to(torch.float8_e4m3fn) + + # Compute reciprocal scales - explicitly cast to float32 + reciprocal_scale = (1.0 / scale.squeeze(1)).to(torch.float32) + # reciprocal_scale is (num_blocks, K), need to transpose to (K, num_blocks) + reciprocal_scale = reciprocal_scale.t().contiguous() + + # Convert to column-major using as_strided (matching Triton kernel output) + s = x.new_empty(K, num_blocks, dtype=torch.float32).as_strided( + (K, num_blocks), + (1, K), # Column-major strides + ) + s.copy_(reciprocal_scale) + + return y, s + + def verify_outputs( + y_naive: torch.Tensor, + s_naive: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + input_tensor: torch.Tensor, + block_size: int, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and naive implementations produce similar results.""" + + # Verify output shapes + M, K = input_tensor.shape + expected_y_shape = (K, M) + expected_s_shape = (K, M // block_size) + + assert y_naive.shape == expected_y_shape, f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}" + assert y_triton.shape == expected_y_shape, f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}" + assert s_naive.shape == expected_s_shape, f"Naive s shape mismatch: {s_naive.shape} vs {expected_s_shape}" + assert s_triton.shape == expected_s_shape, f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}" + + # Convert FP8 back to float for comparison + y_naive_float = y_naive.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + # Check quantized values are close + if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + max_diff = (y_naive_float - y_triton_float).abs().max().item() + print(f"WARNING: Quantized values differ! 
Max diff: {max_diff}") + print( + f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]") + print( + f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]") + + # ROBUST FIX: Handle potential dtype mismatches from torch.compile + # Convert both scales to float32 before any operations + if s_naive.dtype != torch.float32: + print( + f"INFO: Converting naive scales from {s_naive.dtype} to float32") + s_naive = s_naive.to(torch.float32) + + if s_triton.dtype != torch.float32: + print( + f"INFO: Converting Triton scales from {s_triton.dtype} to float32") + s_triton = s_triton.to(torch.float32) + + # Check scales are close + # Note: scales are in column-major format, need to read them correctly + s_naive_rowmajor = s_naive.as_strided( + s_naive.shape, (s_naive.shape[1], 1)) + s_triton_rowmajor = s_triton.as_strided( + s_triton.shape, (s_triton.shape[1], 1)) + + if not torch.allclose(s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol): + max_diff = (s_naive_rowmajor - + s_triton_rowmajor).abs().max().item() + print(f"WARNING: Scales differ! Max diff: {max_diff}") + print( + f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]") + print( + f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]") + + input_tensor = torch.randn( + M, K, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark naive implementation + naive_impl_c = torch.compile(naive_fp8_blockwise_quant_transposed) + + # Benchmark after warmup + y_naive, s_naive = naive_impl_c(input_tensor, block_size) + naive_time_us = benchmark_cuda_function_in_microseconds( + naive_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation + triton_impl_c = torch.compile( + triton_fp8_blockwise_act_quant_transposed_lhs) + + # Benchmark after warmup + y_triton, s_triton = triton_impl_c(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_impl_c, + input_tensor, + block_size, + ) + + # Verify correctness + verify_outputs(y_naive, s_naive, y_triton, + s_triton, input_tensor, block_size) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_scale_el = 4 # float32 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + + s_triton.numel() * bytes_per_scale_el + ) + + naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + naive_us=naive_time_us, + triton_us=triton_time_us, + naive_gbps=naive_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape (M, K)", + "block_size", + "naive_us", + "triton_us", + "speedup", + "naive_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.naive_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.naive_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + 
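+    # Note: `device` is hard-coded to torch.device("cuda") at module level, so
+    # this benchmark is intended to run on a CUDA GPU with Triton available.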
+ print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "="*80) + print("BENCHMARK RESULTS - Transposed LHS Quantization") + print("="*80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() From e5c86015d93d50d3757c5de1e509268e3f4337c7 Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 13:12:07 -0800 Subject: [PATCH 02/19] benchmark for triton_fp8_blockwise_act_quant_lhs against naive implementation torch_blockwise_scale_act_quant_lhs from existing blockwise_fp8_training/kernels --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 270 ++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py new file mode 100644 index 0000000000..0f94ebb590 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + torch_blockwise_scale_act_quant_lhs, + triton_fp8_blockwise_act_quant_lhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, K) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + naive_us: float + triton_us: float + # mem bw + naive_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical transformer activation shapes. + Format: (batch_size * seq_len, hidden_dim) + """ + # Llama-style shapes: various batch*seq_len sizes with typical hidden dims + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, K = config.input_shape + block_size = config.block_size + + def naive_fp8_blockwise_quant( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Naive PyTorch reference implementation for blockwise FP8 quantization. + + Quantizes along dimension 1 (K) with blocks of size block_size. + Each row gets K//block_size scale factors. 
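+
+        A minimal usage sketch (hypothetical shapes; assumes a CUDA device):
+
+            x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
+            y, s = naive_fp8_blockwise_quant(x, block_size=128)
+            # y: (256, 512) in float8_e4m3fn; s: (256, 4) float32, column-major,
+            # where x ~= y.float() * s broadcast over each (1 x 128) group.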
+ + Args: + x: Input tensor of shape (M, K) + block_size: Number of elements per block + + Returns: + y: Quantized tensor in FP8 (M, K) + s: Reciprocal scales in column-major format (M, K//block_size) + """ + + M, K = x.size() + + # Reshape to (M, K) where K is treated as multiple tile_size blocks + y, s_reciprocal = torch_blockwise_scale_act_quant_lhs( + x, tile_size=block_size) + + # Convert scales from row-major to column-major format to match Triton kernel + num_blocks = K // block_size + s = x.new_empty(M, num_blocks, dtype=torch.float32).as_strided( + (M, num_blocks), + (1, M), # Column-major strides + ) + s.copy_(s_reciprocal) + + return y, s + + def verify_outputs( + y_naive: torch.Tensor, + s_naive: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + input_tensor: torch.Tensor, + block_size: int, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and naive implementations produce similar results.""" + + # Convert FP8 back to float for comparison + y_naive_float = y_naive.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + # Check quantized values are close + if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + max_diff = (y_naive_float - y_triton_float).abs().max().item() + print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print( + f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + ) + print( + f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + ) + + # ROBUST FIX: Handle potential dtype mismatches from torch.compile + # Convert both scales to float32 before any operations + if s_naive.dtype != torch.float32: + print( + f"INFO: Converting naive scales from {s_naive.dtype} to float32") + s_naive = s_naive.to(torch.float32) + + if s_triton.dtype != torch.float32: + print( + f"INFO: Converting Triton scales from {s_triton.dtype} to float32") + s_triton = s_triton.to(torch.float32) + + # Check scales are close + # Note: scales are in column-major format, need to read them correctly + s_naive_rowmajor = s_naive.as_strided( + s_naive.shape, (s_naive.shape[1], 1)) + s_triton_rowmajor = s_triton.as_strided( + s_triton.shape, (s_triton.shape[1], 1)) + + if not torch.allclose( + s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol + ): + max_diff = (s_naive_rowmajor - + s_triton_rowmajor).abs().max().item() + print(f"WARNING: Scales differ! 
Max diff: {max_diff}") + print( + f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]" + ) + print( + f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]" + ) + + input_tensor = torch.randn( + M, + K, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark naive implementation + naive_impl_c = torch.compile(naive_fp8_blockwise_quant) + y_naive, s_naive = naive_impl_c(input_tensor, block_size) + naive_time_us = benchmark_cuda_function_in_microseconds( + naive_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation + triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_lhs) + y_triton, s_triton = triton_impl_c(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_impl_c, + input_tensor, + block_size, + ) + + # Verify correctness (optional, can comment out for pure benchmarking) + verify_outputs(y_naive, s_naive, y_triton, + s_triton, input_tensor, block_size) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_scale_el = 4 # float32 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + naive_us=naive_time_us, + triton_us=triton_time_us, + naive_gbps=naive_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape (M, K)", + "block_size", + "naive_us", + "triton_us", + "speedup", + "naive_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.naive_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.naive_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() From ac3b55087b2f90247ffeb0b56d75fec4833a5033 Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 13:12:52 -0800 Subject: [PATCH 03/19] benchmark for triton_fp8_blockwise_act_quant_rhs against naive implementation --- ...ench_triton_fp8_blockwise_act_quant_rhs.py | 273 ++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py new file mode 100644 index 0000000000..9b9007f5b9 --- /dev/null +++ 
b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -0,0 +1,273 @@ +# rhs_benchmark.py +# Copyright (c) Meta Platforms... +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_rhs, # <- RHS kernel +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, K) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + naive_us: float + triton_us: float + # mem bw + naive_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical transformer activation shapes. + Format: (batch_size * seq_len, hidden_dim) + """ + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, K = config.input_shape + block_size = config.block_size + + def naive_fp8_blockwise_quant( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Naive PyTorch reference implementation for RHS blockwise FP8 quantization. + + RHS semantics: + • Groups are (block_size x 1) along the M dimension (rows). + • y is returned in column-major layout (M, K) with strides (1, M). + • s has shape (ceil(M/block_size), K) in row-major (reciprocal scales). + """ + assert x.is_contiguous(), "Input must be contiguous" + + M, K = x.size() + M_blocks = (M + block_size - 1) // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Pad rows so we can reshape without a loop; then crop back. 
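+        # (F.pad's (0, 0, 0, pad_rows) spec pads the last dim by (0, 0) and the
+        #  first dim by (0, pad_rows), so zero rows are appended at the bottom;
+        #  padded zeros cannot raise a block's amax, and all-zero columns are
+        #  handled by the eps clamp below.)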
+ pad_rows = M_blocks * block_size - M + if pad_rows: + x = torch.nn.functional.pad( + x, (0, 0, 0, pad_rows)) # pad rows at bottom + + # Reshape to (M_blocks, block_size, K) for block-wise operations along M + x_reshaped = x.view(M_blocks, block_size, K) + + # Compute max abs per column within each block -> (M_blocks, K) + amax = torch.clamp( + x_reshaped.abs().amax(dim=1).to(torch.float64), + min=eps, + max=float("inf"), + ) + + # Compute scales -> (M_blocks, 1, K) for broadcasting across rows in block + scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) + + # Quantize (still (M_blocks, block_size, K)) + y_reshaped = torch.clamp( + x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Back to (M_padded, K), then crop to (M, K) + y_rowmajor = y_reshaped.view(M_blocks * block_size, K)[:M, :].to( + torch.float8_e4m3fn + ) + + # y must be column-major per RHS kernel contract + y = x.new_empty(M, K, dtype=torch.float8_e4m3fn).as_strided( + (M, K), (1, M)) + y.copy_(y_rowmajor) + + # Reciprocal scales (row-major) -> (M_blocks, K) + reciprocal_scale = (1.0 / scale.squeeze(1) + ).to(torch.float32) # (M_blocks, K) + s = reciprocal_scale # already row-major and correct shape + + return y, s + + def verify_outputs( + y_naive: torch.Tensor, + s_naive: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and naive implementations produce similar results.""" + + # Quantized tensors (both are column-major; convert to float to compare) + y_naive_float = y_naive.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + max_diff = (y_naive_float - y_triton_float).abs().max().item() + print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print( + f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + ) + print( + f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + ) + + # Ensure float32 for scales + if s_naive.dtype != torch.float32: + s_naive = s_naive.to(torch.float32) + if s_triton.dtype != torch.float32: + s_triton = s_triton.to(torch.float32) + + # RHS: scales are already row-major (M_blocks, K) on both paths; compare directly + if not torch.allclose(s_naive, s_triton, rtol=rtol, atol=atol): + max_diff = (s_naive - s_triton).abs().max().item() + print(f"WARNING: Scales differ! 
Max diff: {max_diff}") + print( + f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") + print( + f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") + + input_tensor = torch.randn( + M, + K, + dtype=torch.bfloat16, + device=device, + ) + + # Compile once + naive_impl_c = torch.compile(naive_fp8_blockwise_quant) + + # Benchmark naive implementation + y_naive, s_naive = naive_impl_c(input_tensor, block_size) + naive_time_us = benchmark_cuda_function_in_microseconds( + naive_impl_c, + input_tensor, + block_size, + ) + + triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_rhs) + + # Benchmark Triton implementation + y_triton, s_triton = triton_impl_c(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_impl_c, + input_tensor, + block_size, + ) + + # Verify correctness (compare to naive) + verify_outputs(y_naive, s_naive, y_triton, s_triton) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_scale_el = 4 # float32 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + naive_us=naive_time_us, + triton_us=triton_time_us, + naive_gbps=naive_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape (M, K)", + "block_size", + "naive_us", + "triton_us", + "speedup", + "naive_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.naive_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.naive_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking RHS"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS (RHS)") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() From ee3a26eac7e9fcdb13734f370b74b8f38900b41a Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 13:13:21 -0800 Subject: [PATCH 04/19] bench for triton_fp8_blockwise_weight_quant_rhs against naive torch implementation --- ...h_triton_fp8_blockwise_weight_quant_rhs.py | 348 ++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py new file mode 100644 index 0000000000..ab7cd16cd8 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py @@ -0,0 +1,348 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_weight_quant_rhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, N) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + naive_us: float + triton_us: float + # mem bw + naive_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical weight matrix shapes. + Format: (hidden_dim, hidden_dim) for square matrices or (hidden_dim_in, hidden_dim_out) + + Note: Both M and N must be divisible by block_size (128) + """ + # Common weight matrix shapes in transformers + # Format: (in_features, out_features) for weight matrices + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + # Verify both dimensions are divisible by block_size + if shape[0] % block_size == 0 and shape[1] % block_size == 0: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + """ + Run benchmark experiment comparing naive and Triton implementations. + """ + M, N = config.input_shape + block_size = config.block_size + + def naive_fp8_blockwise_weight_quant( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Naive PyTorch reference implementation for blockwise FP8 weight quantization. + + Quantizes in (block_size x block_size) blocks. Each block gets one scale factor. + Outputs in column-major format for RHS operator. 
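+
+        A minimal usage sketch (hypothetical shapes; assumes a CUDA device):
+
+            w = torch.randn(512, 4096, dtype=torch.bfloat16, device="cuda")
+            y, s = naive_fp8_blockwise_weight_quant(w, block_size=128)
+            # y: (512, 4096) float8_e4m3fn with column-major strides (1, 512);
+            # s: (4, 32) float32 reciprocal scales, one per 128x128 block.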
+ + Args: + x: Input tensor of shape (M, N), row-major + block_size: Block size for quantization + + Returns: + y: Quantized tensor in FP8, shape (M, N), column-major format + s: Reciprocal scales in column-major format (M//block_size, N//block_size) + """ + assert x.is_contiguous(), "Input must be contiguous" + assert x.dim() == 2, "Input must be 2D" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + "Both dimensions must be divisible by block_size" + ) + + M, N = x.size() + M_blocks = M // block_size + N_blocks = N // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Reshape to (M_blocks, block_size, N_blocks, block_size) for block-wise operations + x_reshaped = x.view(M_blocks, block_size, N_blocks, block_size) + # Permute to (M_blocks, N_blocks, block_size, block_size) for easier block processing + x_blocks = x_reshaped.permute(0, 2, 1, 3) + + # Compute max absolute value per block (M_blocks, N_blocks) + amax = torch.clamp( + x_blocks.reshape(M_blocks, N_blocks, -1) + .abs() + .amax(dim=2) + .to(torch.float64), + min=eps, + max=float("inf"), + ) + + # Compute scales (M_blocks, N_blocks) + scale = (max_fp8_e4m3 / amax).to(torch.float32) + + # Broadcast scale for quantization (M_blocks, N_blocks, 1, 1) + scale_broadcast = scale.unsqueeze(2).unsqueeze(3) + + # Quantize + y_blocks = x_blocks * scale_broadcast + y_blocks = torch.clamp(y_blocks, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Reshape back and convert to FP8 + # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) + y_reshaped = y_blocks.permute(0, 2, 1, 3) + y_rowmajor = y_reshaped.reshape(M, N).contiguous() + + # Convert to column-major format + # y = y_rowmajor.to(torch.float8_e4m3fn) + # y = y.as_strided(y.size(), (1, M)) # Column-major strides + y = x.new_empty(M, N, dtype=torch.float8_e4m3fn).as_strided( + (M, N), (1, M)) + y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) + + # Compute reciprocal scales - explicitly cast to float32 + reciprocal_scale = (1.0 / scale).to(torch.float32) + + # Convert to column-major using as_strided + s = x.new_empty(M_blocks, N_blocks, dtype=torch.float32).as_strided( + (M_blocks, N_blocks), + (1, M_blocks), # Column-major strides + ) + s.copy_(reciprocal_scale) + + return y, s + + def verify_outputs( + y_naive: torch.Tensor, + s_naive: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + input_tensor: torch.Tensor, + block_size: int, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and naive implementations produce similar results.""" + + # Verify output shapes + M, N = input_tensor.shape + expected_y_shape = (M, N) + expected_s_shape = (M // block_size, N // block_size) + + assert y_naive.shape == expected_y_shape, ( + f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}" + ) + assert y_triton.shape == expected_y_shape, ( + f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}" + ) + assert s_naive.shape == expected_s_shape, ( + f"Naive s shape mismatch: {s_naive.shape} vs {expected_s_shape}" + ) + assert s_triton.shape == expected_s_shape, ( + f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}" + ) + + # Convert FP8 back to float for comparison + # Need to read column-major data correctly + y_naive_rowmajor = y_naive.as_strided(y_naive.shape, (N, 1)) + y_triton_rowmajor = y_triton.as_strided(y_triton.shape, (N, 1)) + + y_naive_float = y_naive_rowmajor.to(torch.float32) + y_triton_float = 
y_triton_rowmajor.to(torch.float32) + + # Check quantized values are close + if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + max_diff = (y_naive_float - y_triton_float).abs().max().item() + print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print( + f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + ) + print( + f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + ) + + # Handle potential dtype mismatches from torch.compile + if s_naive.dtype != torch.float32: + print( + f"INFO: Converting naive scales from {s_naive.dtype} to float32") + s_naive = s_naive.to(torch.float32) + + if s_triton.dtype != torch.float32: + print( + f"INFO: Converting Triton scales from {s_triton.dtype} to float32") + s_triton = s_triton.to(torch.float32) + + # Check scales are close + # Scales are in column-major format, need to read them correctly + s_naive_rowmajor = s_naive.as_strided( + s_naive.shape, (s_naive.shape[1], 1)) + s_triton_rowmajor = s_triton.as_strided( + s_triton.shape, (s_triton.shape[1], 1)) + + if not torch.allclose( + s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol + ): + max_diff = (s_naive_rowmajor - + s_triton_rowmajor).abs().max().item() + print(f"WARNING: Scales differ! Max diff: {max_diff}") + print( + f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]" + ) + print( + f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]" + ) + + # Create input tensor + input_tensor = torch.randn( + M, + N, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark naive implementation (torch.compile handles warmup) + naive_impl_c = torch.compile(naive_fp8_blockwise_weight_quant) + y_naive, s_naive = naive_impl_c(input_tensor, block_size) + naive_time_us = benchmark_cuda_function_in_microseconds( + naive_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation (torch.compile handles warmup) + triton_impl_c = torch.compile(triton_fp8_blockwise_weight_quant_rhs) + y_triton, s_triton = triton_impl_c(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_impl_c, + input_tensor, + block_size, + ) + + # Verify correctness + verify_outputs(y_naive, s_naive, y_triton, + s_triton, input_tensor, block_size) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_scale_el = 4 # float32 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + naive_us=naive_time_us, + triton_us=triton_time_us, + naive_gbps=naive_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + """Print benchmark results in a formatted table.""" + headers = [ + "input_shape (M, N)", + "block_size", + "naive_us", + "triton_us", + "speedup", + "naive_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.naive_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.naive_us:.2f}", + 
f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + """Main benchmark execution.""" + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS - RHS Weight Quantization") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() From 0bc1597d6dafabaf7cee2ac2de833bb4751e2261 Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 13:13:37 -0800 Subject: [PATCH 05/19] bench for triton_fp8_blockwise_weight_quant_transposed_rhs against naive torch --- ...8_blockwise_weight_quant_transposed_rhs.py | 344 ++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py new file mode 100644 index 0000000000..395892176a --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_weight_quant_transposed_rhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, N) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + naive_us: float + triton_us: float + # mem bw + naive_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical weight matrix shapes. 
+ Format: (hidden_dim, hidden_dim) for square matrices or (hidden_dim_in, hidden_dim_out) + + Note: Both M and N must be divisible by block_size (128) + """ + # Common weight matrix shapes in transformers + # Format: (in_features, out_features) for weight matrices + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + # Verify both dimensions are divisible by block_size + if shape[0] % block_size == 0 and shape[1] % block_size == 0: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + """ + Run benchmark experiment comparing naive and Triton implementations. + """ + M, N = config.input_shape + block_size = config.block_size + + def naive_fp8_blockwise_weight_quant_transposed( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Naive PyTorch reference implementation for blockwise FP8 weight quantization with transpose. + + Quantizes in (block_size x block_size) blocks. Each block gets one scale factor. + Outputs transposed tensor (N, M) in column-major format for RHS operator. + + Args: + x: Input tensor of shape (M, N), row-major + block_size: Block size for quantization + + Returns: + y: Transposed quantized tensor in FP8, shape (N, M), column-major format + s: Reciprocal scales in column-major format (N//block_size, M//block_size) + """ + assert x.is_contiguous(), "Input must be contiguous" + assert x.dim() == 2, "Input must be 2D" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + "Both dimensions must be divisible by block_size" + ) + + M, N = x.size() + M_blocks = M // block_size + N_blocks = N // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Reshape to (M_blocks, block_size, N_blocks, block_size) + x_reshaped = x.view(M_blocks, block_size, N_blocks, block_size) + # Permute to (M_blocks, N_blocks, block_size, block_size) for easier block processing + x_blocks = x_reshaped.permute(0, 2, 1, 3) + + # Compute max absolute value per block (M_blocks, N_blocks) + amax = torch.clamp( + x_blocks.abs().amax(dim=(2, 3)).to(torch.float64), min=eps, max=float("inf") + ) + + # Compute scales (M_blocks, N_blocks) + scale = (max_fp8_e4m3 / amax).to(torch.float32) + + # Broadcast scale for quantization (M_blocks, N_blocks, 1, 1) + scale_broadcast = scale[:, :, None, None] + + # Quantize + y_blocks = x_blocks * scale_broadcast + y_blocks = torch.clamp(y_blocks, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) + y_reshaped = y_blocks.permute(0, 2, 1, 3) + # Reshape to (M, N) then transpose to (N, M) + y_rowmajor = y_reshaped.reshape(M, N).t().contiguous() + + # Convert to FP8 and create column-major output (matching Triton kernel) + y = x.new_empty(N, M, dtype=torch.float8_e4m3fn).as_strided( + (N, M), (1, N)) + y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) + + # Compute reciprocal scales + reciprocal_scale = (1.0 / scale).to(torch.float32) + # Transpose scale matrix to match output dimensions: (M_blocks, N_blocks) -> (N_blocks, M_blocks) + reciprocal_scale = reciprocal_scale.t().contiguous() + + # Convert to column-major using as_strided + s = x.new_empty(N_blocks, M_blocks, 
dtype=torch.float32).as_strided( + (N_blocks, M_blocks), + (1, N_blocks), # Column-major strides + ) + s.copy_(reciprocal_scale) + + return y, s + + def verify_outputs( + y_naive: torch.Tensor, + s_naive: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + input_tensor: torch.Tensor, + block_size: int, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and naive implementations produce similar results.""" + + # Verify output shapes + M, N = input_tensor.shape + expected_y_shape = (N, M) + expected_s_shape = (N // block_size, M // block_size) + + assert y_naive.shape == expected_y_shape, ( + f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}" + ) + assert y_triton.shape == expected_y_shape, ( + f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}" + ) + assert s_naive.shape == expected_s_shape, ( + f"Naive s shape mismatch: {s_naive.shape} vs {expected_s_shape}" + ) + assert s_triton.shape == expected_s_shape, ( + f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}" + ) + + # Convert FP8 back to float for comparison + # Need to read column-major data correctly - y is (N, M) with strides (1, N) + y_naive_rowmajor = y_naive.as_strided(y_naive.shape, (M, 1)) + y_triton_rowmajor = y_triton.as_strided(y_triton.shape, (M, 1)) + + y_naive_float = y_naive_rowmajor.to(torch.float32) + y_triton_float = y_triton_rowmajor.to(torch.float32) + + # Check quantized values are close + if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + max_diff = (y_naive_float - y_triton_float).abs().max().item() + print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print( + f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + ) + print( + f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + ) + + # Handle potential dtype mismatches from torch.compile + if s_naive.dtype != torch.float32: + print( + f"INFO: Converting naive scales from {s_naive.dtype} to float32") + s_naive = s_naive.to(torch.float32) + + if s_triton.dtype != torch.float32: + print( + f"INFO: Converting Triton scales from {s_triton.dtype} to float32") + s_triton = s_triton.to(torch.float32) + + # Check scales are close + # Scales are in column-major format, need to read them correctly + s_naive_rowmajor = s_naive.as_strided( + s_naive.shape, (s_naive.shape[1], 1)) + s_triton_rowmajor = s_triton.as_strided( + s_triton.shape, (s_triton.shape[1], 1)) + + if not torch.allclose( + s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol + ): + max_diff = (s_naive_rowmajor - + s_triton_rowmajor).abs().max().item() + print(f"WARNING: Scales differ! 
Max diff: {max_diff}") + print( + f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]" + ) + print( + f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]" + ) + + # Create input tensor + input_tensor = torch.randn( + M, + N, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark naive implementation (torch.compile handles warmup) + naive_impl_c = torch.compile(naive_fp8_blockwise_weight_quant_transposed) + y_naive, s_naive = naive_impl_c(input_tensor, block_size) + naive_time_us = benchmark_cuda_function_in_microseconds( + naive_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation (torch.compile handles warmup) + triton_impl_c = torch.compile( + triton_fp8_blockwise_weight_quant_transposed_rhs) + y_triton, s_triton = triton_impl_c(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_impl_c, + input_tensor, + block_size, + ) + + # Verify correctness + verify_outputs(y_naive, s_naive, y_triton, + s_triton, input_tensor, block_size) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_scale_el = 4 # float32 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + naive_us=naive_time_us, + triton_us=triton_time_us, + naive_gbps=naive_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + """Print benchmark results in a formatted table.""" + headers = [ + "input_shape (M, N)", + "block_size", + "naive_us", + "triton_us", + "speedup", + "naive_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.naive_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.naive_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + """Main benchmark execution.""" + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS - Transposed RHS Weight Quantization") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() From a36bb48ea7edaa30db5d06be11a8ab63059d4a6e Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 14:26:44 -0800 Subject: [PATCH 06/19] removed extra space from file name --- ...ed_lhs => bench_triton_fp8_blockwise_act_quant_transposed_lhs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename benchmarks/prototype/blockwise_fp8_training/{bench_ triton_fp8_blockwise_act_quant_transposed_lhs => bench_triton_fp8_blockwise_act_quant_transposed_lhs} (100%) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_ 
triton_fp8_blockwise_act_quant_transposed_lhs b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs similarity index 100% rename from benchmarks/prototype/blockwise_fp8_training/bench_ triton_fp8_blockwise_act_quant_transposed_lhs rename to benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs From 0b8b05e9f70d730d1ca32ed39d5e6e938068bcef Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 15:11:33 -0800 Subject: [PATCH 07/19] Flipped mem layout of scales to streamline the LHS activation quantization --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 36 +------ .../blockwise_fp8_training/kernels.py | 96 ++++++++++++------- 2 files changed, 65 insertions(+), 67 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index 0f94ebb590..b1b843aad9 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -79,40 +79,6 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: M, K = config.input_shape block_size = config.block_size - def naive_fp8_blockwise_quant( - x: torch.Tensor, block_size: int = 128 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Naive PyTorch reference implementation for blockwise FP8 quantization. - - Quantizes along dimension 1 (K) with blocks of size block_size. - Each row gets K//block_size scale factors. - - Args: - x: Input tensor of shape (M, K) - block_size: Number of elements per block - - Returns: - y: Quantized tensor in FP8 (M, K) - s: Reciprocal scales in column-major format (M, K//block_size) - """ - - M, K = x.size() - - # Reshape to (M, K) where K is treated as multiple tile_size blocks - y, s_reciprocal = torch_blockwise_scale_act_quant_lhs( - x, tile_size=block_size) - - # Convert scales from row-major to column-major format to match Triton kernel - num_blocks = K // block_size - s = x.new_empty(M, num_blocks, dtype=torch.float32).as_strided( - (M, num_blocks), - (1, M), # Column-major strides - ) - s.copy_(s_reciprocal) - - return y, s - def verify_outputs( y_naive: torch.Tensor, s_naive: torch.Tensor, @@ -180,7 +146,7 @@ def verify_outputs( ) # Benchmark naive implementation - naive_impl_c = torch.compile(naive_fp8_blockwise_quant) + naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) y_naive, s_naive = naive_impl_c(input_tensor, block_size) naive_time_us = benchmark_cuda_function_in_microseconds( naive_impl_c, diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py index 3f82407d40..397fe515bb 100644 --- a/torchao/prototype/blockwise_fp8_training/kernels.py +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -91,7 +91,8 @@ def triton_fp8_gemm_1x128_128x128_kernel( b_ptrs += BLOCK_SIZE_K * b_stride_dim_0 c = accumulator.to(c_ptr.dtype.element_ty) - c_ptrs = c_ptr + offs_m[:, None] * c_stride_dim_0 + offs_n[None, :] * c_stride_dim_1 + c_ptrs = c_ptr + offs_m[:, None] * \ + c_stride_dim_0 + offs_n[None, :] * c_stride_dim_1 c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) @@ -116,7 +117,8 @@ def triton_fp8_gemm_1x128_128x128( K = a.size(1) N = b.size(1) c = a.new_empty(M, N, dtype=out_dtype) - grid = lambda META: ( + + def grid(META): return ( 
triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), ) @@ -231,7 +233,8 @@ def triton_fp8_gemm_1x128_128x1( K = a.size(1) N = b.size(1) c = a.new_empty(M, N, dtype=out_dtype) - grid = lambda META: ( + + def grid(META): return ( triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), ) @@ -304,7 +307,8 @@ def triton_fp8_blockwise_act_quant_lhs_kernel( # Load (num_groups x block_size) tile of x, where input is row major m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS) k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + \ + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) x = tl.load(x_ptr + x_offs, mask=x_mask) @@ -313,13 +317,16 @@ def triton_fp8_blockwise_act_quant_lhs_kernel( min_fp8_e4m3 = -448.0 # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1) - amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, + max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None] y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( + y_ptr.dtype.element_ty) # Write output to column major fomrat - y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1 + y_offs = m_offs[:, None] * y_stride_dim_0 + \ + k_offs[None, :] * y_stride_dim_1 y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) tl.store(y_ptr + y_offs, y, mask=y_mask) @@ -350,7 +357,8 @@ def triton_fp8_blockwise_act_quant_lhs( (M, K // block_size), (1, M), ) - grid = lambda meta: ( + + def grid(meta): return ( triton.cdiv(M, meta["NUM_GROUPS"]), triton.cdiv(K, meta["BLOCK_SIZE"]), ) @@ -398,7 +406,8 @@ def triton_fp8_blockwise_act_quant_rhs_kernel( # to facilitate coalesced gmem accesses and improve efficiency. 
m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) - x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + \ + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) x = tl.load(x_ptr + x_offs, mask=x_mask) @@ -407,13 +416,16 @@ def triton_fp8_blockwise_act_quant_rhs_kernel( min_fp8_e4m3 = -448.0 # Column-wise scales for RHS operand, shape (1, block_size) - amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, + max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :] y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( + y_ptr.dtype.element_ty) # Write output to column major format - y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1 + y_offs = m_offs[:, None] * y_stride_dim_0 + \ + k_offs[None, :] * y_stride_dim_1 y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) tl.store(y_ptr + y_offs, y, mask=y_mask) @@ -443,7 +455,7 @@ def triton_fp8_blockwise_act_quant_rhs( y = y.as_strided(y.size(), (1, y.size(0))) s = x.new_empty(M_blocks, K, dtype=torch.float32) - grid = lambda meta: ( + def grid(meta): return ( triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(K, meta["NUM_GROUPS"]), ) @@ -496,7 +508,8 @@ def triton_fp8_blockwise_act_quant_transposed_lhs_kernel( # which will fail to launch for large tensors, due to max block number of 65535. m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) - x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + \ + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) x = tl.load(x_ptr + x_offs, mask=x_mask) @@ -505,13 +518,16 @@ def triton_fp8_blockwise_act_quant_transposed_lhs_kernel( min_fp8_e4m3 = -448.0 # Compute amax across dim 0 (column-wise). 
- amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, + max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32) y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( + y_ptr.dtype.element_ty) # Write output to column major fomrat - y_offs = k_offs[:, None] * y_stride_dim_0 + m_offs[None, :] * y_stride_dim_1 + y_offs = k_offs[:, None] * y_stride_dim_0 + \ + m_offs[None, :] * y_stride_dim_1 y_mask = (k_offs[:, None] < K) & (m_offs[None, :] < M) tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) @@ -549,7 +565,8 @@ def triton_fp8_blockwise_act_quant_transposed_lhs( (K, M_blocks), # shape (1, K), # stride ) - grid = lambda meta: ( + + def grid(meta): return ( triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(K, meta["NUM_GROUPS"]), ) @@ -596,20 +613,24 @@ def triton_fp8_blockwise_weight_quant_rhs_kernel( offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # Load (block_size x block_size) block of x, where input is row major - x_offs = offs_m[:, None] * x_stride_dim_0 + offs_n[None, :] * x_stride_dim_1 + x_offs = offs_m[:, None] * x_stride_dim_0 + \ + offs_n[None, :] * x_stride_dim_1 x_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) x = tl.load(x_ptr + x_offs, mask=x_mask) # Scale the data max_fp8_e4m3 = 448.0 min_fp8_e4m3 = -448.0 - amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, + max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32) y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( + y_ptr.dtype.element_ty) # Store output in column major format - y_offs = offs_m[:, None] * y_stride_dim_0 + offs_n[None, :] * y_stride_dim_1 + y_offs = offs_m[:, None] * y_stride_dim_0 + \ + offs_n[None, :] * y_stride_dim_1 y_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(y_ptr + y_offs, y, mask=y_mask) @@ -639,7 +660,8 @@ def triton_fp8_blockwise_weight_quant_rhs( (M_blocks, N_blocks), # shape (1, M_blocks), # stride ) - grid = lambda meta: ( + + def grid(meta): return ( triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]), ) @@ -697,27 +719,32 @@ def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel( # Load (block_size x block_size) block of input, where input is row major m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) n_offs = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x_offs = m_offs[:, None] * x_stride_dim_0 + n_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + \ + n_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (n_offs[None, :] < N) x = tl.load(x_ptr + x_offs, mask=x_mask).to(tl.float32) # Perform scaling max_fp8_e4m3 = 448.0 min_fp8_e4m3 = -448.0 - amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, + max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32) y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( + y_ptr.dtype.element_ty) # Write output to column major fomrat - y_offs = n_offs[:, None] * y_stride_dim_0 + m_offs[None, :] * y_stride_dim_1 + y_offs = n_offs[:, None] * y_stride_dim_0 + \ + m_offs[None, 
:] * y_stride_dim_1 y_mask = (n_offs[:, None] < N) & (m_offs[None, :] < M) tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) # Write reciprocal scales scale_m = pid_m scale_k = pid_n - scale_offs = scale_k[:, None] * s_stride_dim_0 + scale_m[None, :] * s_stride_dim_1 + scale_offs = scale_k[:, None] * s_stride_dim_0 + \ + scale_m[None, :] * s_stride_dim_1 scale_mask = (scale_k[:, None] < N // BLOCK_SIZE) & ( scale_m[None, :] < M // BLOCK_SIZE ) @@ -744,7 +771,8 @@ def triton_fp8_blockwise_weight_quant_transposed_rhs( (n_blocks, m_blocks), # shape (1, n_blocks), # stride ) - grid = lambda meta: ( + + def grid(meta): return ( triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]), ) @@ -789,11 +817,13 @@ def torch_blockwise_scale_act_quant_lhs(x, tile_size=128): s = (fp8_dtype_max / x_amax).to(torch.float32) # Apply scale and clamp - x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) + x = (x * s).clamp(min=fp8_dtype_min, + max=fp8_dtype_max).to(torch.float8_e4m3fn) # Reshape quantized output back to original shape and reshape scales accordingly x = x.reshape(*orig_shape) s = s.reshape(orig_shape[0], -1).to(torch.float) + s = s.transpose(-2, -1).contiguous().transpose(-2, -1) # Return output tensor and reciprocal scale return x, 1.0 / s @@ -831,7 +861,8 @@ def torch_blockwise_scale_act_quant_rhs( x_col = x_blocks[:, :, k] # (num_blocks_m, block_size) # Compute absolute max for each block - amax = torch.abs(x_col).max(dim=1, keepdim=True)[0] # (num_blocks_m, 1) + amax = torch.abs(x_col).max(dim=1, keepdim=True)[ + 0] # (num_blocks_m, 1) # Clamp to avoid division by zero amax = torch.clamp(amax, min=eps).to(torch.float64) @@ -889,7 +920,8 @@ def torch_blockwise_scale_weight_quant(x, tile_size=128): s = (fp8_dtype_max / x_amax).to(torch.float32) # Apply scale and clamp - x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) + x = (x * s).clamp(min=fp8_dtype_min, + max=fp8_dtype_max).to(torch.float8_e4m3fn) # Reshape quantized output and scales back to 2D x = x.reshape(t_h, t_w, tile_size, tile_size) From 066b34615de04251b9863cee36b14542bbde886e Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 16:09:34 -0800 Subject: [PATCH 08/19] updates the LHS act --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index b1b843aad9..1381352cad 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -108,35 +108,41 @@ def verify_outputs( # ROBUST FIX: Handle potential dtype mismatches from torch.compile # Convert both scales to float32 before any operations - if s_naive.dtype != torch.float32: - print( - f"INFO: Converting naive scales from {s_naive.dtype} to float32") - s_naive = s_naive.to(torch.float32) + # if s_naive.dtype != torch.float32: + # print( + # f"INFO: Converting naive scales from {s_naive.dtype} to float32") + # s_naive = s_naive.to(torch.float32) - if s_triton.dtype != torch.float32: - print( - f"INFO: Converting Triton scales from {s_triton.dtype} to float32") - s_triton = s_triton.to(torch.float32) + # if s_triton.dtype != torch.float32: + # print( + # f"INFO: Converting Triton scales from {s_triton.dtype} to 
float32") + # s_triton = s_triton.to(torch.float32) # Check scales are close # Note: scales are in column-major format, need to read them correctly - s_naive_rowmajor = s_naive.as_strided( - s_naive.shape, (s_naive.shape[1], 1)) - s_triton_rowmajor = s_triton.as_strided( - s_triton.shape, (s_triton.shape[1], 1)) - - if not torch.allclose( - s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol - ): - max_diff = (s_naive_rowmajor - - s_triton_rowmajor).abs().max().item() + # s_naive_rowmajor = s_naive.as_strided( + # s_naive.shape, (s_naive.shape[1], 1)) + # s_triton_rowmajor = s_triton.as_strided( + # s_triton.shape, (s_triton.shape[1], 1)) + + try: + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations" + ) + except AssertionError as e: + max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! Max diff: {max_diff}") print( - f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]" + f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" ) print( - f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]" + f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" ) + print(f" Error details: {e}") input_tensor = torch.randn( M, @@ -146,10 +152,11 @@ def verify_outputs( ) # Benchmark naive implementation - naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) - y_naive, s_naive = naive_impl_c(input_tensor, block_size) + # naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) + y_naive, s_naive = torch_blockwise_scale_act_quant_lhs( + input_tensor, block_size) naive_time_us = benchmark_cuda_function_in_microseconds( - naive_impl_c, + torch_blockwise_scale_act_quant_lhs, input_tensor, block_size, ) @@ -168,8 +175,8 @@ def verify_outputs( s_triton, input_tensor, block_size) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 - bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_input_el = torch.finfo(torch.float32).bits / 8 + bytes_per_output_el = torch.finfo(torch.float32).bits / 8 bytes_per_scale_el = 4 # float32 read_bytes = input_tensor.numel() * bytes_per_input_el From 4ad066a350dffe3f4a9c247accf5a4dceb5a713f Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 16:20:18 -0800 Subject: [PATCH 09/19] output bytes calc corrected --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index 1381352cad..b7e11409ac 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -96,34 +96,24 @@ def verify_outputs( y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + try: + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations" + ) + except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Quantized values differ! 
Max diff: {max_diff}") + print(f"WARNING: Scales differ! Max diff: {max_diff}") print( - f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + f" Naive scale range: [{y_naive_float.min():.6f}, {y_naive_float.max():.6f}]" ) print( - f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + f" Triton scale range: [{y_triton_float.min():.6f}, {y_triton_float.max():.6f}]" ) - - # ROBUST FIX: Handle potential dtype mismatches from torch.compile - # Convert both scales to float32 before any operations - # if s_naive.dtype != torch.float32: - # print( - # f"INFO: Converting naive scales from {s_naive.dtype} to float32") - # s_naive = s_naive.to(torch.float32) - - # if s_triton.dtype != torch.float32: - # print( - # f"INFO: Converting Triton scales from {s_triton.dtype} to float32") - # s_triton = s_triton.to(torch.float32) - - # Check scales are close - # Note: scales are in column-major format, need to read them correctly - # s_naive_rowmajor = s_naive.as_strided( - # s_naive.shape, (s_naive.shape[1], 1)) - # s_triton_rowmajor = s_triton.as_strided( - # s_triton.shape, (s_triton.shape[1], 1)) + print(f" Error details: {e}") try: torch.testing.assert_close( @@ -176,7 +166,7 @@ def verify_outputs( # Memory bandwidth calculations bytes_per_input_el = torch.finfo(torch.float32).bits / 8 - bytes_per_output_el = torch.finfo(torch.float32).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 bytes_per_scale_el = 4 # float32 read_bytes = input_tensor.numel() * bytes_per_input_el From e464ad5343cd62707599450b197c393e9ef9f8f4 Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 16:23:47 -0800 Subject: [PATCH 10/19] minor changes to RHS activation bench --- ...ench_triton_fp8_blockwise_act_quant_rhs.py | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py index 9b9007f5b9..ec02ff0156 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -148,30 +148,43 @@ def verify_outputs( y_naive_float = y_naive.to(torch.float32) y_triton_float = y_triton.to(torch.float32) - if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + try: + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations" + ) + except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print(f"WARNING: Scales differ! 
Max diff: {max_diff}") print( - f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + f" Naive scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" ) print( - f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + f" Triton scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" ) - - # Ensure float32 for scales - if s_naive.dtype != torch.float32: - s_naive = s_naive.to(torch.float32) - if s_triton.dtype != torch.float32: - s_triton = s_triton.to(torch.float32) - - # RHS: scales are already row-major (M_blocks, K) on both paths; compare directly - if not torch.allclose(s_naive, s_triton, rtol=rtol, atol=atol): + print(f" Error details: {e}") + + try: + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations" + ) + except AssertionError as e: max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! Max diff: {max_diff}") print( - f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") + f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" + ) print( - f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") + f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" + ) + print(f" Error details: {e}") input_tensor = torch.randn( M, @@ -205,7 +218,7 @@ def verify_outputs( verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_input_el = torch.finfo(torch.float32).bits / 8 bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 bytes_per_scale_el = 4 # float32 From 36f34ca84edb9473061890d681d8d2817b7972a9 Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 16:47:22 -0800 Subject: [PATCH 11/19] changes to testing --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 4 +- ...ton_fp8_blockwise_act_quant_transposed_lhs | 78 +++++++--------- ...h_triton_fp8_blockwise_weight_quant_rhs.py | 91 +++++++------------ ...8_blockwise_weight_quant_transposed_rhs.py | 88 +++++++----------- 4 files changed, 98 insertions(+), 163 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index b7e11409ac..16011540a2 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -84,8 +84,6 @@ def verify_outputs( s_naive: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, - input_tensor: torch.Tensor, - block_size: int, rtol: float = 1e-2, atol: float = 1e-2, ): @@ -162,7 +160,7 @@ def verify_outputs( # Verify correctness (optional, can comment out for pure benchmarking) verify_outputs(y_naive, s_naive, y_triton, - s_triton, input_tensor, block_size) + s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(torch.float32).bits / 8 diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs index 4415971883..3a572c8772 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs +++ 
b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs @@ -154,63 +154,55 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: s_naive: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, - input_tensor: torch.Tensor, - block_size: int, rtol: float = 1e-2, atol: float = 1e-2, ): """Verify that Triton and naive implementations produce similar results.""" - # Verify output shapes - M, K = input_tensor.shape - expected_y_shape = (K, M) - expected_s_shape = (K, M // block_size) - - assert y_naive.shape == expected_y_shape, f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}" - assert y_triton.shape == expected_y_shape, f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}" - assert s_naive.shape == expected_s_shape, f"Naive s shape mismatch: {s_naive.shape} vs {expected_s_shape}" - assert s_triton.shape == expected_s_shape, f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}" - # Convert FP8 back to float for comparison y_naive_float = y_naive.to(torch.float32) y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): - max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Quantized values differ! Max diff: {max_diff}") - print( - f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]") - print( - f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]") - # ROBUST FIX: Handle potential dtype mismatches from torch.compile - # Convert both scales to float32 before any operations - if s_naive.dtype != torch.float32: + # Check quantized values are close + try: + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations" + ) + except AssertionError as e: + max_diff = (y_naive_float - y_triton_float).abs().max().item() + print(f"WARNING: Scales differ! Max diff: {max_diff}") print( - f"INFO: Converting naive scales from {s_naive.dtype} to float32") - s_naive = s_naive.to(torch.float32) - - if s_triton.dtype != torch.float32: + f" Naive scale range: [{y_naive_float.min():.6f}, {y_naive_float.max():.6f}]" + ) print( - f"INFO: Converting Triton scales from {s_triton.dtype} to float32") - s_triton = s_triton.to(torch.float32) - - # Check scales are close - # Note: scales are in column-major format, need to read them correctly - s_naive_rowmajor = s_naive.as_strided( - s_naive.shape, (s_naive.shape[1], 1)) - s_triton_rowmajor = s_triton.as_strided( - s_triton.shape, (s_triton.shape[1], 1)) - - if not torch.allclose(s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol): - max_diff = (s_naive_rowmajor - - s_triton_rowmajor).abs().max().item() + f" Triton scale range: [{y_triton_float.min():.6f}, {y_triton_float.max():.6f}]" + ) + print(f" Error details: {e}") + + try: + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations" + ) + except AssertionError as e: + max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! 
Max diff: {max_diff}") print( - f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]") + f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" + ) print( - f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]") + f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" + ) + print(f" Error details: {e}") input_tensor = torch.randn( M, K, @@ -243,10 +235,10 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: # Verify correctness verify_outputs(y_naive, s_naive, y_triton, - s_triton, input_tensor, block_size) + s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_input_el = torch.finfo(torch.float32).bits / 8 bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 bytes_per_scale_el = 4 # float32 diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py index ab7cd16cd8..b87efc4f04 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py @@ -149,8 +149,6 @@ def naive_fp8_blockwise_weight_quant( y_rowmajor = y_reshaped.reshape(M, N).contiguous() # Convert to column-major format - # y = y_rowmajor.to(torch.float8_e4m3fn) - # y = y.as_strided(y.size(), (1, M)) # Column-major strides y = x.new_empty(M, N, dtype=torch.float8_e4m3fn).as_strided( (M, N), (1, M)) y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) @@ -172,80 +170,54 @@ def verify_outputs( s_naive: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, - input_tensor: torch.Tensor, - block_size: int, rtol: float = 1e-2, atol: float = 1e-2, ): """Verify that Triton and naive implementations produce similar results.""" - # Verify output shapes - M, N = input_tensor.shape - expected_y_shape = (M, N) - expected_s_shape = (M // block_size, N // block_size) - - assert y_naive.shape == expected_y_shape, ( - f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}" - ) - assert y_triton.shape == expected_y_shape, ( - f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}" - ) - assert s_naive.shape == expected_s_shape, ( - f"Naive s shape mismatch: {s_naive.shape} vs {expected_s_shape}" - ) - assert s_triton.shape == expected_s_shape, ( - f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}" - ) - # Convert FP8 back to float for comparison - # Need to read column-major data correctly - y_naive_rowmajor = y_naive.as_strided(y_naive.shape, (N, 1)) - y_triton_rowmajor = y_triton.as_strided(y_triton.shape, (N, 1)) - y_naive_float = y_naive_rowmajor.to(torch.float32) - y_triton_float = y_triton_rowmajor.to(torch.float32) + y_naive_float = y_naive.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + try: + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations" + ) + except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print(f"WARNING: Scales differ! 
Max diff: {max_diff}") print( - f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + f" Naive scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" ) print( - f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + f" Triton scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" ) - - # Handle potential dtype mismatches from torch.compile - if s_naive.dtype != torch.float32: - print( - f"INFO: Converting naive scales from {s_naive.dtype} to float32") - s_naive = s_naive.to(torch.float32) - - if s_triton.dtype != torch.float32: - print( - f"INFO: Converting Triton scales from {s_triton.dtype} to float32") - s_triton = s_triton.to(torch.float32) - - # Check scales are close - # Scales are in column-major format, need to read them correctly - s_naive_rowmajor = s_naive.as_strided( - s_naive.shape, (s_naive.shape[1], 1)) - s_triton_rowmajor = s_triton.as_strided( - s_triton.shape, (s_triton.shape[1], 1)) - - if not torch.allclose( - s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol - ): - max_diff = (s_naive_rowmajor - - s_triton_rowmajor).abs().max().item() + print(f" Error details: {e}") + + try: + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations" + ) + except AssertionError as e: + max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! Max diff: {max_diff}") print( - f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]" + f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" ) print( - f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]" + f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" ) + print(f" Error details: {e}") # Create input tensor input_tensor = torch.randn( @@ -274,11 +246,10 @@ def verify_outputs( ) # Verify correctness - verify_outputs(y_naive, s_naive, y_triton, - s_triton, input_tensor, block_size) + verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_input_el = torch.finfo(torch.float32).bits / 8 bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 bytes_per_scale_el = 4 # float32 diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py index 395892176a..d3a7b0fd27 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -167,80 +167,54 @@ def verify_outputs( s_naive: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, - input_tensor: torch.Tensor, - block_size: int, rtol: float = 1e-2, atol: float = 1e-2, ): """Verify that Triton and naive implementations produce similar results.""" - # Verify output shapes - M, N = input_tensor.shape - expected_y_shape = (N, M) - expected_s_shape = (N // block_size, M // block_size) - - assert y_naive.shape == expected_y_shape, ( - f"Naive y shape mismatch: {y_naive.shape} vs {expected_y_shape}" - ) - assert y_triton.shape == expected_y_shape, ( - f"Triton y shape mismatch: {y_triton.shape} vs {expected_y_shape}" - ) - assert s_naive.shape == expected_s_shape, ( - f"Naive s shape mismatch: 
{s_naive.shape} vs {expected_s_shape}" - ) - assert s_triton.shape == expected_s_shape, ( - f"Triton s shape mismatch: {s_triton.shape} vs {expected_s_shape}" - ) - # Convert FP8 back to float for comparison - # Need to read column-major data correctly - y is (N, M) with strides (1, N) - y_naive_rowmajor = y_naive.as_strided(y_naive.shape, (M, 1)) - y_triton_rowmajor = y_triton.as_strided(y_triton.shape, (M, 1)) - y_naive_float = y_naive_rowmajor.to(torch.float32) - y_triton_float = y_triton_rowmajor.to(torch.float32) + y_naive_float = y_naive.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol): + try: + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations" + ) + except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Quantized values differ! Max diff: {max_diff}") + print(f"WARNING: Scales differ! Max diff: {max_diff}") print( - f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]" + f" Naive scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" ) print( - f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]" + f" Triton scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" ) - - # Handle potential dtype mismatches from torch.compile - if s_naive.dtype != torch.float32: - print( - f"INFO: Converting naive scales from {s_naive.dtype} to float32") - s_naive = s_naive.to(torch.float32) - - if s_triton.dtype != torch.float32: - print( - f"INFO: Converting Triton scales from {s_triton.dtype} to float32") - s_triton = s_triton.to(torch.float32) - - # Check scales are close - # Scales are in column-major format, need to read them correctly - s_naive_rowmajor = s_naive.as_strided( - s_naive.shape, (s_naive.shape[1], 1)) - s_triton_rowmajor = s_triton.as_strided( - s_triton.shape, (s_triton.shape[1], 1)) - - if not torch.allclose( - s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol - ): - max_diff = (s_naive_rowmajor - - s_triton_rowmajor).abs().max().item() + print(f" Error details: {e}") + + try: + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations" + ) + except AssertionError as e: + max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! 
Max diff: {max_diff}") print( - f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]" + f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" ) print( - f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]" + f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" ) + print(f" Error details: {e}") # Create input tensor input_tensor = torch.randn( @@ -271,10 +245,10 @@ def verify_outputs( # Verify correctness verify_outputs(y_naive, s_naive, y_triton, - s_triton, input_tensor, block_size) + s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_input_el = torch.finfo(torch.float32).bits / 8 bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 bytes_per_scale_el = 4 # float32 From 278cb70d3fa6b48245bd964a4695679c73378f82 Mon Sep 17 00:00:00 2001 From: agolajko Date: Thu, 6 Nov 2025 17:38:16 -0800 Subject: [PATCH 12/19] forgot to lint --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 19 +-- ...ench_triton_fp8_blockwise_act_quant_rhs.py | 25 +-- ...h_triton_fp8_blockwise_weight_quant_rhs.py | 16 +- ...8_blockwise_weight_quant_transposed_rhs.py | 22 +-- .../blockwise_fp8_training/kernels.py | 144 ++++++++---------- 5 files changed, 92 insertions(+), 134 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index 16011540a2..dcef2e16c3 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -58,7 +58,6 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), - ] configs = [] @@ -100,7 +99,7 @@ def verify_outputs( y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations" + msg="Quantized values differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() @@ -119,17 +118,13 @@ def verify_outputs( s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations" + msg="Scales differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print( - f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" - ) - print( - f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" - ) + print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") + print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") print(f" Error details: {e}") input_tensor = torch.randn( @@ -141,8 +136,7 @@ def verify_outputs( # Benchmark naive implementation # naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) - y_naive, s_naive = torch_blockwise_scale_act_quant_lhs( - input_tensor, block_size) + y_naive, s_naive = torch_blockwise_scale_act_quant_lhs(input_tensor, block_size) naive_time_us = benchmark_cuda_function_in_microseconds( torch_blockwise_scale_act_quant_lhs, input_tensor, @@ -159,8 +153,7 @@ def verify_outputs( ) # Verify correctness (optional, can comment out for pure benchmarking) - verify_outputs(y_naive, s_naive, y_triton, - s_triton) + verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(torch.float32).bits / 8 diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py index ec02ff0156..08310cfeb9 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -52,7 +52,6 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), - ] configs = [] @@ -97,8 +96,7 @@ def naive_fp8_blockwise_quant( # Pad rows so we can reshape without a loop; then crop back. pad_rows = M_blocks * block_size - M if pad_rows: - x = torch.nn.functional.pad( - x, (0, 0, 0, pad_rows)) # pad rows at bottom + x = torch.nn.functional.pad(x, (0, 0, 0, pad_rows)) # pad rows at bottom # Reshape to (M_blocks, block_size, K) for block-wise operations along M x_reshaped = x.view(M_blocks, block_size, K) @@ -114,8 +112,7 @@ def naive_fp8_blockwise_quant( scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) # Quantize (still (M_blocks, block_size, K)) - y_reshaped = torch.clamp( - x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) + y_reshaped = torch.clamp(x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) # Back to (M_padded, K), then crop to (M, K) y_rowmajor = y_reshaped.view(M_blocks * block_size, K)[:M, :].to( @@ -123,13 +120,11 @@ def naive_fp8_blockwise_quant( ) # y must be column-major per RHS kernel contract - y = x.new_empty(M, K, dtype=torch.float8_e4m3fn).as_strided( - (M, K), (1, M)) + y = x.new_empty(M, K, dtype=torch.float8_e4m3fn).as_strided((M, K), (1, M)) y.copy_(y_rowmajor) # Reciprocal scales (row-major) -> (M_blocks, K) - reciprocal_scale = (1.0 / scale.squeeze(1) - ).to(torch.float32) # (M_blocks, K) + reciprocal_scale = (1.0 / scale.squeeze(1)).to(torch.float32) # (M_blocks, K) s = reciprocal_scale # already row-major and correct shape return y, s @@ -154,7 +149,7 @@ def verify_outputs( y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations" + msg="Quantized values differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() @@ -173,17 +168,13 @@ def verify_outputs( s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton 
implementations" + msg="Scales differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! Max diff: {max_diff}") - print( - f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" - ) - print( - f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" - ) + print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") + print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") print(f" Error details: {e}") input_tensor = torch.randn( diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py index b87efc4f04..8201f9dcd7 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py @@ -60,7 +60,6 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), - ] configs = [] @@ -149,8 +148,7 @@ def naive_fp8_blockwise_weight_quant( y_rowmajor = y_reshaped.reshape(M, N).contiguous() # Convert to column-major format - y = x.new_empty(M, N, dtype=torch.float8_e4m3fn).as_strided( - (M, N), (1, M)) + y = x.new_empty(M, N, dtype=torch.float8_e4m3fn).as_strided((M, N), (1, M)) y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) # Compute reciprocal scales - explicitly cast to float32 @@ -187,7 +185,7 @@ def verify_outputs( y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations" + msg="Quantized values differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() @@ -206,17 +204,13 @@ def verify_outputs( s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations" + msg="Scales differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print( - f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" - ) - print( - f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" - ) + print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") + print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") print(f" Error details: {e}") # Create input tensor diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py index d3a7b0fd27..d6eca41ce8 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -60,7 +60,6 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), - ] configs = [] @@ -144,8 +143,7 @@ def naive_fp8_blockwise_weight_quant_transposed( y_rowmajor = y_reshaped.reshape(M, N).t().contiguous() # Convert to FP8 and create column-major output (matching Triton kernel) - y = x.new_empty(N, M, dtype=torch.float8_e4m3fn).as_strided( - (N, M), (1, N)) + y = x.new_empty(N, M, dtype=torch.float8_e4m3fn).as_strided((N, M), (1, N)) y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) # Compute reciprocal scales @@ -184,7 +182,7 @@ def verify_outputs( y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations" + msg="Quantized values differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (y_naive_float - y_triton_float).abs().max().item() @@ -203,17 +201,13 @@ def verify_outputs( s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations" + msg="Scales differ between naive and Triton implementations", ) except AssertionError as e: max_diff = (s_naive - s_triton).abs().max().item() print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print( - f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" - ) - print( - f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" - ) + print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") + print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") print(f" Error details: {e}") # Create input tensor @@ -234,8 +228,7 @@ def verify_outputs( ) # Benchmark Triton implementation (torch.compile handles warmup) - triton_impl_c = torch.compile( - triton_fp8_blockwise_weight_quant_transposed_rhs) + triton_impl_c = torch.compile(triton_fp8_blockwise_weight_quant_transposed_rhs) y_triton, s_triton = triton_impl_c(input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( triton_impl_c, @@ -244,8 +237,7 @@ def verify_outputs( ) # Verify correctness - verify_outputs(y_naive, s_naive, y_triton, - s_triton) + verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(torch.float32).bits / 8 diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py index 397fe515bb..2ceb839173 100644 --- a/torchao/prototype/blockwise_fp8_training/kernels.py +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -91,8 +91,7 @@ def triton_fp8_gemm_1x128_128x128_kernel( b_ptrs += BLOCK_SIZE_K * b_stride_dim_0 c = accumulator.to(c_ptr.dtype.element_ty) - c_ptrs = c_ptr + offs_m[:, None] * \ - c_stride_dim_0 + offs_n[None, :] * c_stride_dim_1 + c_ptrs = c_ptr + offs_m[:, None] * c_stride_dim_0 + offs_n[None, :] * c_stride_dim_1 c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) @@ -118,10 +117,12 @@ def triton_fp8_gemm_1x128_128x128( N = b.size(1) c = a.new_empty(M, N, dtype=out_dtype) - def grid(META): return ( - triton.cdiv(M, META["BLOCK_SIZE_M"]), - triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid]( a, a.stride(0), @@ -234,10 +235,12 @@ def triton_fp8_gemm_1x128_128x1( N = b.size(1) c = a.new_empty(M, N, dtype=out_dtype) - def grid(META): return ( - triton.cdiv(M, META["BLOCK_SIZE_M"]), - triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid]( a, a.stride(0), @@ -307,8 +310,7 @@ def triton_fp8_blockwise_act_quant_lhs_kernel( # Load (num_groups x block_size) tile of x, where input is row major m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS) k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x_offs = m_offs[:, None] * x_stride_dim_0 + \ - k_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) x = tl.load(x_ptr + x_offs, mask=x_mask) @@ -317,16 +319,13 @@ def triton_fp8_blockwise_act_quant_lhs_kernel( min_fp8_e4m3 = -448.0 # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1) - amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, - max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None] y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( - y_ptr.dtype.element_ty) + y 
= tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) # Write output to column major fomrat - y_offs = m_offs[:, None] * y_stride_dim_0 + \ - k_offs[None, :] * y_stride_dim_1 + y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1 y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) tl.store(y_ptr + y_offs, y, mask=y_mask) @@ -358,10 +357,12 @@ def triton_fp8_blockwise_act_quant_lhs( (1, M), ) - def grid(meta): return ( - triton.cdiv(M, meta["NUM_GROUPS"]), - triton.cdiv(K, meta["BLOCK_SIZE"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["NUM_GROUPS"]), + triton.cdiv(K, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid]( x, x.stride(0), @@ -406,8 +407,7 @@ def triton_fp8_blockwise_act_quant_rhs_kernel( # to facilitate coalesced gmem accesses and improve efficiency. m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) - x_offs = m_offs[:, None] * x_stride_dim_0 + \ - k_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) x = tl.load(x_ptr + x_offs, mask=x_mask) @@ -416,16 +416,13 @@ def triton_fp8_blockwise_act_quant_rhs_kernel( min_fp8_e4m3 = -448.0 # Column-wise scales for RHS operand, shape (1, block_size) - amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, - max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :] y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( - y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) # Write output to column major format - y_offs = m_offs[:, None] * y_stride_dim_0 + \ - k_offs[None, :] * y_stride_dim_1 + y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1 y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) tl.store(y_ptr + y_offs, y, mask=y_mask) @@ -455,10 +452,12 @@ def triton_fp8_blockwise_act_quant_rhs( y = y.as_strided(y.size(), (1, y.size(0))) s = x.new_empty(M_blocks, K, dtype=torch.float32) - def grid(meta): return ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(K, meta["NUM_GROUPS"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(K, meta["NUM_GROUPS"]), + ) + wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid]( x, x.stride(0), @@ -508,8 +507,7 @@ def triton_fp8_blockwise_act_quant_transposed_lhs_kernel( # which will fail to launch for large tensors, due to max block number of 65535. m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) - x_offs = m_offs[:, None] * x_stride_dim_0 + \ - k_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) x = tl.load(x_ptr + x_offs, mask=x_mask) @@ -518,16 +516,13 @@ def triton_fp8_blockwise_act_quant_transposed_lhs_kernel( min_fp8_e4m3 = -448.0 # Compute amax across dim 0 (column-wise). 
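    # Clamping amax to EPS keeps the reciprocal scale finite for all-zero columns,
    # and the float64 promotion presumably guards against precision loss in the
    # division before the result is cast back to float32.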
- amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, - max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32) y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( - y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) # Write output to column major fomrat - y_offs = k_offs[:, None] * y_stride_dim_0 + \ - m_offs[None, :] * y_stride_dim_1 + y_offs = k_offs[:, None] * y_stride_dim_0 + m_offs[None, :] * y_stride_dim_1 y_mask = (k_offs[:, None] < K) & (m_offs[None, :] < M) tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) @@ -566,10 +561,11 @@ def triton_fp8_blockwise_act_quant_transposed_lhs( (1, K), # stride ) - def grid(meta): return ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(K, meta["NUM_GROUPS"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(K, meta["NUM_GROUPS"]), + ) wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid]( x, @@ -613,24 +609,20 @@ def triton_fp8_blockwise_weight_quant_rhs_kernel( offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # Load (block_size x block_size) block of x, where input is row major - x_offs = offs_m[:, None] * x_stride_dim_0 + \ - offs_n[None, :] * x_stride_dim_1 + x_offs = offs_m[:, None] * x_stride_dim_0 + offs_n[None, :] * x_stride_dim_1 x_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) x = tl.load(x_ptr + x_offs, mask=x_mask) # Scale the data max_fp8_e4m3 = 448.0 min_fp8_e4m3 = -448.0 - amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, - max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32) y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( - y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) # Store output in column major format - y_offs = offs_m[:, None] * y_stride_dim_0 + \ - offs_n[None, :] * y_stride_dim_1 + y_offs = offs_m[:, None] * y_stride_dim_0 + offs_n[None, :] * y_stride_dim_1 y_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(y_ptr + y_offs, y, mask=y_mask) @@ -661,10 +653,12 @@ def triton_fp8_blockwise_weight_quant_rhs( (1, M_blocks), # stride ) - def grid(meta): return ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid]( x, x.stride(0), @@ -719,32 +713,27 @@ def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel( # Load (block_size x block_size) block of input, where input is row major m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) n_offs = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x_offs = m_offs[:, None] * x_stride_dim_0 + \ - n_offs[None, :] * x_stride_dim_1 + x_offs = m_offs[:, None] * x_stride_dim_0 + n_offs[None, :] * x_stride_dim_1 x_mask = (m_offs[:, None] < M) & (n_offs[None, :] < N) x = tl.load(x_ptr + x_offs, mask=x_mask).to(tl.float32) # Perform scaling max_fp8_e4m3 = 448.0 min_fp8_e4m3 = -448.0 - amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, - max=float("inf")).to(tl.float64) + amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64) scale = (max_fp8_e4m3 / amax).to(tl.float32) y = x * scale - y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to( 
- y_ptr.dtype.element_ty) + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) # Write output to column major fomrat - y_offs = n_offs[:, None] * y_stride_dim_0 + \ - m_offs[None, :] * y_stride_dim_1 + y_offs = n_offs[:, None] * y_stride_dim_0 + m_offs[None, :] * y_stride_dim_1 y_mask = (n_offs[:, None] < N) & (m_offs[None, :] < M) tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) # Write reciprocal scales scale_m = pid_m scale_k = pid_n - scale_offs = scale_k[:, None] * s_stride_dim_0 + \ - scale_m[None, :] * s_stride_dim_1 + scale_offs = scale_k[:, None] * s_stride_dim_0 + scale_m[None, :] * s_stride_dim_1 scale_mask = (scale_k[:, None] < N // BLOCK_SIZE) & ( scale_m[None, :] < M // BLOCK_SIZE ) @@ -772,10 +761,12 @@ def triton_fp8_blockwise_weight_quant_transposed_rhs( (1, n_blocks), # stride ) - def grid(meta): return ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid]( x, x.stride(0), @@ -817,8 +808,7 @@ def torch_blockwise_scale_act_quant_lhs(x, tile_size=128): s = (fp8_dtype_max / x_amax).to(torch.float32) # Apply scale and clamp - x = (x * s).clamp(min=fp8_dtype_min, - max=fp8_dtype_max).to(torch.float8_e4m3fn) + x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) # Reshape quantized output back to original shape and reshape scales accordingly x = x.reshape(*orig_shape) @@ -861,8 +851,7 @@ def torch_blockwise_scale_act_quant_rhs( x_col = x_blocks[:, :, k] # (num_blocks_m, block_size) # Compute absolute max for each block - amax = torch.abs(x_col).max(dim=1, keepdim=True)[ - 0] # (num_blocks_m, 1) + amax = torch.abs(x_col).max(dim=1, keepdim=True)[0] # (num_blocks_m, 1) # Clamp to avoid division by zero amax = torch.clamp(amax, min=eps).to(torch.float64) @@ -920,8 +909,7 @@ def torch_blockwise_scale_weight_quant(x, tile_size=128): s = (fp8_dtype_max / x_amax).to(torch.float32) # Apply scale and clamp - x = (x * s).clamp(min=fp8_dtype_min, - max=fp8_dtype_max).to(torch.float8_e4m3fn) + x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) # Reshape quantized output and scales back to 2D x = x.reshape(t_h, t_w, tile_size, tile_size) From 873ba819c068f48ff4a44de21dfb24fe15d8f34f Mon Sep 17 00:00:00 2001 From: agolajko Date: Fri, 7 Nov 2025 07:18:42 -0800 Subject: [PATCH 13/19] Act LHS minor changes re PR comments --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 68 +++++++------------ 1 file changed, 26 insertions(+), 42 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index dcef2e16c3..38d139b335 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -93,39 +93,22 @@ def verify_outputs( y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - try: - torch.testing.assert_close( - y_naive_float, - y_triton_float, - rtol=rtol, - atol=atol, - msg="Quantized values differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print( - f" Naive scale range: [{y_naive_float.min():.6f}, {y_naive_float.max():.6f}]" - ) - print( - f" Triton scale range: [{y_triton_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print(f" Error details: {e}") - - try: - torch.testing.assert_close( - s_naive, - s_triton, - rtol=rtol, - atol=atol, - msg="Scales differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (s_naive - s_triton).abs().max().item() - print(f"WARNING: Scales differ! Max diff: {max_diff}") - print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") - print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") - print(f" Error details: {e}") + + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations", + ) + + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations", + ) input_tensor = torch.randn( M, @@ -135,19 +118,20 @@ def verify_outputs( ) # Benchmark naive implementation - # naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) - y_naive, s_naive = torch_blockwise_scale_act_quant_lhs(input_tensor, block_size) + naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) + y_naive, s_naive = naive_impl_c( + input_tensor, block_size) naive_time_us = benchmark_cuda_function_in_microseconds( - torch_blockwise_scale_act_quant_lhs, + naive_impl_c, input_tensor, block_size, ) # Benchmark Triton implementation - triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_lhs) - y_triton, s_triton = triton_impl_c(input_tensor, block_size) + y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs( + input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( - triton_impl_c, + triton_fp8_blockwise_act_quant_lhs, input_tensor, block_size, ) @@ -156,9 +140,9 @@ def verify_outputs( verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.float32).bits / 8 - bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 - bytes_per_scale_el = 4 # float32 + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 read_bytes = input_tensor.numel() * bytes_per_input_el write_bytes = ( From 8281e7b2795fff8cd3670bb3663078262aa0fc8f Mon Sep 17 00:00:00 2001 From: agolajko Date: Fri, 7 Nov 2025 07:29:46 -0800 Subject: [PATCH 14/19] Act RHS changes re PR comments --- ...ench_triton_fp8_blockwise_act_quant_rhs.py | 73 +++++++------------ 1 file changed, 28 insertions(+), 45 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py index 08310cfeb9..7595ec227f 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -96,7 +96,8 @@ def naive_fp8_blockwise_quant( # Pad rows so we can reshape without a loop; then crop back. 
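    # Zero padding does not change each block's absolute max, and the padded
    # rows are cropped again after quantization, so the result matches the
    # unpadded computation.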
pad_rows = M_blocks * block_size - M if pad_rows: - x = torch.nn.functional.pad(x, (0, 0, 0, pad_rows)) # pad rows at bottom + x = torch.nn.functional.pad( + x, (0, 0, 0, pad_rows)) # pad rows at bottom # Reshape to (M_blocks, block_size, K) for block-wise operations along M x_reshaped = x.view(M_blocks, block_size, K) @@ -112,7 +113,8 @@ def naive_fp8_blockwise_quant( scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) # Quantize (still (M_blocks, block_size, K)) - y_reshaped = torch.clamp(x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) + y_reshaped = torch.clamp( + x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) # Back to (M_padded, K), then crop to (M, K) y_rowmajor = y_reshaped.view(M_blocks * block_size, K)[:M, :].to( @@ -120,11 +122,11 @@ def naive_fp8_blockwise_quant( ) # y must be column-major per RHS kernel contract - y = x.new_empty(M, K, dtype=torch.float8_e4m3fn).as_strided((M, K), (1, M)) - y.copy_(y_rowmajor) + y = y_rowmajor.t().contiguous().t() # Reciprocal scales (row-major) -> (M_blocks, K) - reciprocal_scale = (1.0 / scale.squeeze(1)).to(torch.float32) # (M_blocks, K) + reciprocal_scale = (1.0 / scale.squeeze(1) + ).to(torch.float32) # (M_blocks, K) s = reciprocal_scale # already row-major and correct shape return y, s @@ -143,39 +145,21 @@ def verify_outputs( y_naive_float = y_naive.to(torch.float32) y_triton_float = y_triton.to(torch.float32) - try: - torch.testing.assert_close( - y_naive_float, - y_triton_float, - rtol=rtol, - atol=atol, - msg="Quantized values differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Scales differ! Max diff: {max_diff}") - print( - f" Naive scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print( - f" Triton scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print(f" Error details: {e}") - - try: - torch.testing.assert_close( - s_naive, - s_triton, - rtol=rtol, - atol=atol, - msg="Scales differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (s_naive - s_triton).abs().max().item() - print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") - print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") - print(f" Error details: {e}") + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations", + ) + + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations", + ) input_tensor = torch.randn( M, @@ -195,12 +179,11 @@ def verify_outputs( block_size, ) - triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_rhs) - # Benchmark Triton implementation - y_triton, s_triton = triton_impl_c(input_tensor, block_size) + y_triton, s_triton = triton_fp8_blockwise_act_quant_rhs( + input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( - triton_impl_c, + triton_fp8_blockwise_act_quant_rhs, input_tensor, block_size, ) @@ -209,9 +192,9 @@ def verify_outputs( verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.float32).bits / 8 - bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 - bytes_per_scale_el = 4 # float32 + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 read_bytes = input_tensor.numel() * bytes_per_input_el write_bytes = ( From 2175611eb62cf9fef16b2a9d049d84fa3e20e3fa Mon Sep 17 00:00:00 2001 From: agolajko Date: Fri, 7 Nov 2025 10:58:46 -0800 Subject: [PATCH 15/19] Weight transposed RHS fixes for pr comments --- ...8_blockwise_weight_quant_transposed_rhs.py | 79 +++++++------------ 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py index d6eca41ce8..1c30177b50 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -140,23 +140,20 @@ def naive_fp8_blockwise_weight_quant_transposed( # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) y_reshaped = y_blocks.permute(0, 2, 1, 3) # Reshape to (M, N) then transpose to (N, M) - y_rowmajor = y_reshaped.reshape(M, N).t().contiguous() + y_rowmajor = y_reshaped.reshape( + M, N).t() # Convert to FP8 and create column-major output (matching Triton kernel) - y = x.new_empty(N, M, dtype=torch.float8_e4m3fn).as_strided((N, M), (1, N)) - y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) + y = y_rowmajor.t().contiguous().t() + y = y_rowmajor.to(torch.float8_e4m3fn) # Compute reciprocal scales - reciprocal_scale = (1.0 / scale).to(torch.float32) + reciprocal_scale = 1.0 / scale # Transpose scale matrix to match output dimensions: (M_blocks, N_blocks) -> (N_blocks, M_blocks) - reciprocal_scale = reciprocal_scale.t().contiguous() + reciprocal_scale = reciprocal_scale.t() - # Convert to column-major using as_strided - s = x.new_empty(N_blocks, M_blocks, dtype=torch.float32).as_strided( - (N_blocks, M_blocks), - (1, N_blocks), # Column-major strides - ) - s.copy_(reciprocal_scale) + # Convert to col-major + s = 
reciprocal_scale.t().contiguous().t() return y, s @@ -176,39 +173,21 @@ def verify_outputs( y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - try: - torch.testing.assert_close( - y_naive_float, - y_triton_float, - rtol=rtol, - atol=atol, - msg="Quantized values differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Scales differ! Max diff: {max_diff}") - print( - f" Naive scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print( - f" Triton scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print(f" Error details: {e}") - - try: - torch.testing.assert_close( - s_naive, - s_triton, - rtol=rtol, - atol=atol, - msg="Scales differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (s_naive - s_triton).abs().max().item() - print(f"WARNING: Scales differ! Max diff: {max_diff}") - print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") - print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") - print(f" Error details: {e}") + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations", + ) + + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations", + ) # Create input tensor input_tensor = torch.randn( @@ -228,10 +207,10 @@ def verify_outputs( ) # Benchmark Triton implementation (torch.compile handles warmup) - triton_impl_c = torch.compile(triton_fp8_blockwise_weight_quant_transposed_rhs) - y_triton, s_triton = triton_impl_c(input_tensor, block_size) + y_triton, s_triton = triton_fp8_blockwise_weight_quant_transposed_rhs( + input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( - triton_impl_c, + triton_fp8_blockwise_weight_quant_transposed_rhs, input_tensor, block_size, ) @@ -240,9 +219,9 @@ def verify_outputs( verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.float32).bits / 8 - bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 - bytes_per_scale_el = 4 # float32 + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 read_bytes = input_tensor.numel() * bytes_per_input_el write_bytes = ( From 8525822f80165d05f5d0b7cd60ae62f82ac7fe9f Mon Sep 17 00:00:00 2001 From: agolajko Date: Fri, 7 Nov 2025 11:02:45 -0800 Subject: [PATCH 16/19] Activation RHS fixes for PR comments --- .../bench_triton_fp8_blockwise_act_quant_rhs.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py index 7595ec227f..cdd3a0df15 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -121,13 +121,12 @@ def naive_fp8_blockwise_quant( torch.float8_e4m3fn ) - # y must be column-major per RHS kernel contract + # y must be column-major per RHS kernel y = y_rowmajor.t().contiguous().t() # Reciprocal scales 
(row-major) -> (M_blocks, K) - reciprocal_scale = (1.0 / scale.squeeze(1) - ).to(torch.float32) # (M_blocks, K) - s = reciprocal_scale # already row-major and correct shape + reciprocal_scale = 1.0 / scale.squeeze(1) + s = reciprocal_scale return y, s From c5b058c41ee5f32896a010e8b4a6b6dabfd1e1a6 Mon Sep 17 00:00:00 2001 From: agolajko Date: Fri, 7 Nov 2025 11:06:33 -0800 Subject: [PATCH 17/19] LHS activation transposed fix for PR comments --- ...ton_fp8_blockwise_act_quant_transposed_lhs | 79 ++++++------------- 1 file changed, 23 insertions(+), 56 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs index 3a572c8772..2a43848a12 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs @@ -136,16 +136,11 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: y = y_reshaped.view(M, K).t().contiguous().to(torch.float8_e4m3fn) # Compute reciprocal scales - explicitly cast to float32 - reciprocal_scale = (1.0 / scale.squeeze(1)).to(torch.float32) + reciprocal_scale = 1.0 / scale.squeeze(1) # reciprocal_scale is (num_blocks, K), need to transpose to (K, num_blocks) - reciprocal_scale = reciprocal_scale.t().contiguous() + reciprocal_scale = reciprocal_scale.t() - # Convert to column-major using as_strided (matching Triton kernel output) - s = x.new_empty(K, num_blocks, dtype=torch.float32).as_strided( - (K, num_blocks), - (1, K), # Column-major strides - ) - s.copy_(reciprocal_scale) + s = reciprocal_scale.t().contiguous().t() return y, s @@ -164,45 +159,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: y_triton_float = y_triton.to(torch.float32) # Check quantized values are close + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations" + ) - # Check quantized values are close - try: - torch.testing.assert_close( - y_naive_float, - y_triton_float, - rtol=rtol, - atol=atol, - msg="Quantized values differ between naive and Triton implementations" - ) - except AssertionError as e: - max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Scales differ! Max diff: {max_diff}") - print( - f" Naive scale range: [{y_naive_float.min():.6f}, {y_naive_float.max():.6f}]" - ) - print( - f" Triton scale range: [{y_triton_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print(f" Error details: {e}") - - try: - torch.testing.assert_close( - s_naive, - s_triton, - rtol=rtol, - atol=atol, - msg="Scales differ between naive and Triton implementations" - ) - except AssertionError as e: - max_diff = (s_naive - s_triton).abs().max().item() - print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print( - f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]" - ) - print( - f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]" - ) - print(f" Error details: {e}") + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations" + ) input_tensor = torch.randn( M, K, @@ -221,14 +192,10 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: block_size, ) - # Benchmark Triton implementation - triton_impl_c = torch.compile( - triton_fp8_blockwise_act_quant_transposed_lhs) - - # Benchmark after warmup - y_triton, s_triton = triton_impl_c(input_tensor, block_size) + y_triton, s_triton = triton_fp8_blockwise_act_quant_transposed_lhs( + input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( - triton_impl_c, + triton_fp8_blockwise_act_quant_transposed_lhs, input_tensor, block_size, ) @@ -238,9 +205,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.float32).bits / 8 - bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 - bytes_per_scale_el = 4 # float32 + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 read_bytes = input_tensor.numel() * bytes_per_input_el write_bytes = ( From 83af1d764938e1b11a80246e4d545d420295915e Mon Sep 17 00:00:00 2001 From: agolajko Date: Fri, 7 Nov 2025 11:12:08 -0800 Subject: [PATCH 18/19] lint --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 6 +- ...ench_triton_fp8_blockwise_act_quant_rhs.py | 11 +-- ...h_triton_fp8_blockwise_weight_quant_rhs.py | 79 +++++++------------ ...8_blockwise_weight_quant_transposed_rhs.py | 6 +- 4 files changed, 37 insertions(+), 65 deletions(-) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index 38d139b335..71ecd70b5f 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -119,8 +119,7 @@ def verify_outputs( # Benchmark naive implementation naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) - y_naive, s_naive = naive_impl_c( - input_tensor, block_size) + y_naive, s_naive = naive_impl_c(input_tensor, block_size) naive_time_us = benchmark_cuda_function_in_microseconds( naive_impl_c, input_tensor, @@ -128,8 +127,7 @@ def verify_outputs( ) # Benchmark Triton implementation - y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs( - input_tensor, block_size) + y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs(input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( triton_fp8_blockwise_act_quant_lhs, input_tensor, diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py index cdd3a0df15..40fabc283e 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -80,7 +80,7 @@ def naive_fp8_blockwise_quant( RHS semantics: • Groups are 
(block_size x 1) along the M dimension (rows). - • y is returned in column-major layout (M, K) with strides (1, M). + • y is returned in column-major layout (M, K). • s has shape (ceil(M/block_size), K) in row-major (reciprocal scales). """ assert x.is_contiguous(), "Input must be contiguous" @@ -96,8 +96,7 @@ def naive_fp8_blockwise_quant( # Pad rows so we can reshape without a loop; then crop back. pad_rows = M_blocks * block_size - M if pad_rows: - x = torch.nn.functional.pad( - x, (0, 0, 0, pad_rows)) # pad rows at bottom + x = torch.nn.functional.pad(x, (0, 0, 0, pad_rows)) # pad rows at bottom # Reshape to (M_blocks, block_size, K) for block-wise operations along M x_reshaped = x.view(M_blocks, block_size, K) @@ -113,8 +112,7 @@ def naive_fp8_blockwise_quant( scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) # Quantize (still (M_blocks, block_size, K)) - y_reshaped = torch.clamp( - x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) + y_reshaped = torch.clamp(x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) # Back to (M_padded, K), then crop to (M, K) y_rowmajor = y_reshaped.view(M_blocks * block_size, K)[:M, :].to( @@ -179,8 +177,7 @@ def verify_outputs( ) # Benchmark Triton implementation - y_triton, s_triton = triton_fp8_blockwise_act_quant_rhs( - input_tensor, block_size) + y_triton, s_triton = triton_fp8_blockwise_act_quant_rhs(input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( triton_fp8_blockwise_act_quant_rhs, input_tensor, diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py index 8201f9dcd7..557402570d 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py @@ -145,21 +145,17 @@ def naive_fp8_blockwise_weight_quant( # Reshape back and convert to FP8 # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) y_reshaped = y_blocks.permute(0, 2, 1, 3) - y_rowmajor = y_reshaped.reshape(M, N).contiguous() + y_rowmajor = y_reshaped.reshape(M, N).to(torch.float8_e4m3fn) # Convert to column-major format - y = x.new_empty(M, N, dtype=torch.float8_e4m3fn).as_strided((M, N), (1, M)) - y.copy_(y_rowmajor.to(torch.float8_e4m3fn)) + + y = y_rowmajor.t().contiguous().t() # Compute reciprocal scales - explicitly cast to float32 - reciprocal_scale = (1.0 / scale).to(torch.float32) + reciprocal_scale = 1.0 / scale - # Convert to column-major using as_strided - s = x.new_empty(M_blocks, N_blocks, dtype=torch.float32).as_strided( - (M_blocks, N_blocks), - (1, M_blocks), # Column-major strides - ) - s.copy_(reciprocal_scale) + # Convert to column-major + s = reciprocal_scale.t().contiguous().t() return y, s @@ -179,39 +175,21 @@ def verify_outputs( y_triton_float = y_triton.to(torch.float32) # Check quantized values are close - try: - torch.testing.assert_close( - y_naive_float, - y_triton_float, - rtol=rtol, - atol=atol, - msg="Quantized values differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (y_naive_float - y_triton_float).abs().max().item() - print(f"WARNING: Scales differ! 
Max diff: {max_diff}") - print( - f" Naive scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print( - f" Triton scale range: [{y_naive_float.min():.6f}, {y_triton_float.max():.6f}]" - ) - print(f" Error details: {e}") - - try: - torch.testing.assert_close( - s_naive, - s_triton, - rtol=rtol, - atol=atol, - msg="Scales differ between naive and Triton implementations", - ) - except AssertionError as e: - max_diff = (s_naive - s_triton).abs().max().item() - print(f"WARNING: Scales differ! Max diff: {max_diff}") - print(f" Naive scale range: [{s_naive.min():.6f}, {s_naive.max():.6f}]") - print(f" Triton scale range: [{s_triton.min():.6f}, {s_triton.max():.6f}]") - print(f" Error details: {e}") + torch.testing.assert_close( + y_naive_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between naive and Triton implementations", + ) + + torch.testing.assert_close( + s_naive, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between naive and Triton implementations", + ) # Create input tensor input_tensor = torch.randn( @@ -221,7 +199,7 @@ def verify_outputs( device=device, ) - # Benchmark naive implementation (torch.compile handles warmup) + # Benchmark naive implementation naive_impl_c = torch.compile(naive_fp8_blockwise_weight_quant) y_naive, s_naive = naive_impl_c(input_tensor, block_size) naive_time_us = benchmark_cuda_function_in_microseconds( @@ -230,11 +208,10 @@ def verify_outputs( block_size, ) - # Benchmark Triton implementation (torch.compile handles warmup) - triton_impl_c = torch.compile(triton_fp8_blockwise_weight_quant_rhs) - y_triton, s_triton = triton_impl_c(input_tensor, block_size) + # Benchmark Triton implementation + y_triton, s_triton = triton_fp8_blockwise_weight_quant_rhs(input_tensor, block_size) triton_time_us = benchmark_cuda_function_in_microseconds( - triton_impl_c, + triton_fp8_blockwise_weight_quant_rhs, input_tensor, block_size, ) @@ -243,9 +220,9 @@ def verify_outputs( verify_outputs(y_naive, s_naive, y_triton, s_triton) # Memory bandwidth calculations - bytes_per_input_el = torch.finfo(torch.float32).bits / 8 - bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 - bytes_per_scale_el = 4 # float32 + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 read_bytes = input_tensor.numel() * bytes_per_input_el write_bytes = ( diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py index 1c30177b50..7421e0f7d0 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -140,8 +140,7 @@ def naive_fp8_blockwise_weight_quant_transposed( # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) y_reshaped = y_blocks.permute(0, 2, 1, 3) # Reshape to (M, N) then transpose to (N, M) - y_rowmajor = y_reshaped.reshape( - M, N).t() + y_rowmajor = y_reshaped.reshape(M, N).t() # Convert to FP8 and create column-major output (matching Triton kernel) y = y_rowmajor.t().contiguous().t() @@ -208,7 +207,8 @@ def verify_outputs( # Benchmark Triton implementation (torch.compile handles warmup) y_triton, s_triton 
= triton_fp8_blockwise_weight_quant_transposed_rhs( - input_tensor, block_size) + input_tensor, block_size + ) triton_time_us = benchmark_cuda_function_in_microseconds( triton_fp8_blockwise_weight_quant_transposed_rhs, input_tensor, From a22d36f86c12af6d8ef54d662653c1d114ff356e Mon Sep 17 00:00:00 2001 From: agolajko Date: Sat, 8 Nov 2025 17:00:27 -0800 Subject: [PATCH 19/19] Fixes to comments --- ...ench_triton_fp8_blockwise_act_quant_lhs.py | 66 ++++++++----- ...ench_triton_fp8_blockwise_act_quant_rhs.py | 72 +++++++++------ ...fp8_blockwise_act_quant_transposed_lhs.py} | 92 +++++++++++-------- ...h_triton_fp8_blockwise_weight_quant_rhs.py | 72 +++++++++------ ...8_blockwise_weight_quant_transposed_rhs.py | 72 +++++++++------ 5 files changed, 230 insertions(+), 144 deletions(-) rename benchmarks/prototype/blockwise_fp8_training/{bench_triton_fp8_blockwise_act_quant_transposed_lhs => bench_triton_fp8_blockwise_act_quant_transposed_lhs.py} (74%) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py index 71ecd70b5f..ed80a358d1 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -33,10 +33,10 @@ class ExperimentConfig: @dataclass(frozen=True) class ExperimentResult: # time - naive_us: float + torch_us: float triton_us: float # mem bw - naive_gbps: float + torch_gbps: float triton_gbps: float @@ -58,6 +58,10 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), ] configs = [] @@ -79,35 +83,49 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: block_size = config.block_size def verify_outputs( - y_naive: torch.Tensor, - s_naive: torch.Tensor, + y_torch: torch.Tensor, + s_torch: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, rtol: float = 1e-2, atol: float = 1e-2, ): - """Verify that Triton and naive implementations produce similar results.""" + """Verify that Triton and torch implementations produce similar results.""" # Convert FP8 back to float for comparison - y_naive_float = y_naive.to(torch.float32) + y_torch_float = y_torch.to(torch.float32) y_triton_float = y_triton.to(torch.float32) + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + # Check quantized values are close torch.testing.assert_close( - y_naive_float, + y_torch_float, y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations", + msg="Quantized values differ between torch and Triton implementations", ) torch.testing.assert_close( - s_naive, + s_torch, s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations", + msg="Scales differ between torch and Triton implementations", ) input_tensor = torch.randn( @@ -117,11 +135,11 @@ def verify_outputs( device=device, 
) - # Benchmark naive implementation - naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) - y_naive, s_naive = naive_impl_c(input_tensor, block_size) - naive_time_us = benchmark_cuda_function_in_microseconds( - naive_impl_c, + # Benchmark torch implementation + torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, input_tensor, block_size, ) @@ -135,7 +153,7 @@ def verify_outputs( ) # Verify correctness (optional, can comment out for pure benchmarking) - verify_outputs(y_naive, s_naive, y_triton, s_triton) + verify_outputs(y_torch, s_torch, y_triton, s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 @@ -147,13 +165,13 @@ def verify_outputs( y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el ) - naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) return ExperimentResult( - naive_us=naive_time_us, + torch_us=torch_time_us, triton_us=triton_time_us, - naive_gbps=naive_gbps, + torch_gbps=torch_gbps, triton_gbps=triton_gbps, ) @@ -162,23 +180,23 @@ def print_results(experiments: List[Experiment]): headers = [ "input_shape (M, K)", "block_size", - "naive_us", + "torch_us", "triton_us", "speedup", - "naive_gbps", + "torch_gbps", "triton_gbps", ] rows = [] for experiment in experiments: - speedup = experiment.result.naive_us / experiment.result.triton_us + speedup = experiment.result.torch_us / experiment.result.triton_us rows.append( [ f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", experiment.config.block_size, - f"{experiment.result.naive_us:.2f}", + f"{experiment.result.torch_us:.2f}", f"{experiment.result.triton_us:.2f}", f"{speedup:.2f}x", - f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.torch_gbps:.1f}", f"{experiment.result.triton_gbps:.1f}", ] ) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py index 40fabc283e..72f20e28fa 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -28,10 +28,10 @@ class ExperimentConfig: @dataclass(frozen=True) class ExperimentResult: # time - naive_us: float + torch_us: float triton_us: float # mem bw - naive_gbps: float + torch_gbps: float triton_gbps: float @@ -52,6 +52,10 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), ] configs = [] @@ -72,11 +76,11 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: M, K = config.input_shape block_size = config.block_size - def naive_fp8_blockwise_quant( + def torch_fp8_blockwise_quant( x: torch.Tensor, block_size: int = 128 ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Naive PyTorch reference implementation for RHS blockwise FP8 quantization. + Torch reference implementation for RHS blockwise FP8 quantization. RHS semantics: • Groups are (block_size x 1) along the M dimension (rows). 
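Illustrative aside: the t().contiguous().t() idiom used throughout these reference implementations yields a tensor with the same shape and values but column-major strides, which is the (1, M) layout the RHS kernel contract expects for y. A minimal sketch of what the idiom does:

    import torch

    x = torch.randn(4, 8)                # row-major, strides (8, 1)
    col_major = x.t().contiguous().t()   # same shape and values, strides (1, 4)
    assert col_major.shape == x.shape
    assert col_major.stride() == (1, x.size(0))
    assert torch.equal(col_major, x)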
@@ -129,33 +133,47 @@ def naive_fp8_blockwise_quant( return y, s def verify_outputs( - y_naive: torch.Tensor, - s_naive: torch.Tensor, + y_torch: torch.Tensor, + s_torch: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, rtol: float = 1e-2, atol: float = 1e-2, ): - """Verify that Triton and naive implementations produce similar results.""" + """Verify that Triton and torch implementations produce similar results.""" # Quantized tensors (both are column-major; convert to float to compare) - y_naive_float = y_naive.to(torch.float32) + y_torch_float = y_torch.to(torch.float32) y_triton_float = y_triton.to(torch.float32) + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + torch.testing.assert_close( - y_naive_float, + y_torch_float, y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations", + msg="Quantized values differ between torch and Triton implementations", ) torch.testing.assert_close( - s_naive, + s_torch, s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations", + msg="Scales differ between torch and Triton implementations", ) input_tensor = torch.randn( @@ -166,12 +184,12 @@ def verify_outputs( ) # Compile once - naive_impl_c = torch.compile(naive_fp8_blockwise_quant) + torch_impl_c = torch.compile(torch_fp8_blockwise_quant) - # Benchmark naive implementation - y_naive, s_naive = naive_impl_c(input_tensor, block_size) - naive_time_us = benchmark_cuda_function_in_microseconds( - naive_impl_c, + # Benchmark torch implementation + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, input_tensor, block_size, ) @@ -184,8 +202,8 @@ def verify_outputs( block_size, ) - # Verify correctness (compare to naive) - verify_outputs(y_naive, s_naive, y_triton, s_triton) + # Verify correctness (compare to torch) + verify_outputs(y_torch, s_torch, y_triton, s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 @@ -197,13 +215,13 @@ def verify_outputs( y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el ) - naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) return ExperimentResult( - naive_us=naive_time_us, + torch_us=torch_time_us, triton_us=triton_time_us, - naive_gbps=naive_gbps, + torch_gbps=torch_gbps, triton_gbps=triton_gbps, ) @@ -212,23 +230,23 @@ def print_results(experiments: List[Experiment]): headers = [ "input_shape (M, K)", "block_size", - "naive_us", + "torch_us", "triton_us", "speedup", - "naive_gbps", + "torch_gbps", "triton_gbps", ] rows = [] for experiment in experiments: - speedup = experiment.result.naive_us / experiment.result.triton_us + speedup = experiment.result.torch_us / experiment.result.triton_us rows.append( [ f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", 
experiment.config.block_size, - f"{experiment.result.naive_us:.2f}", + f"{experiment.result.torch_us:.2f}", f"{experiment.result.triton_us:.2f}", f"{speedup:.2f}x", - f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.torch_gbps:.1f}", f"{experiment.result.triton_gbps:.1f}", ] ) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs.py similarity index 74% rename from benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs rename to benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs.py index 2a43848a12..c5496a97db 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs.py @@ -32,10 +32,10 @@ class ExperimentConfig: @dataclass(frozen=True) class ExperimentResult: # time - naive_us: float + torch_us: float triton_us: float # mem bw - naive_gbps: float + torch_gbps: float triton_gbps: float @@ -60,7 +60,10 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), - + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), ] configs = [] @@ -83,11 +86,11 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: M, K = config.input_shape block_size = config.block_size - def naive_fp8_blockwise_quant_transposed( + def torch_fp8_blockwise_quant_transposed( x: torch.Tensor, block_size: int = 128 ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Naive PyTorch reference implementation for blockwise FP8 quantization with transpose. + Torch reference implementation for blockwise FP8 quantization with transpose. This version: 1. 
Computes column-wise scales (along dimension 0) @@ -119,9 +122,7 @@ def naive_fp8_blockwise_quant_transposed( # Compute max absolute value per block along dimension 1 (block_size) # Result shape: (num_blocks, K) amax = torch.clamp( - x_reshaped.abs().amax(dim=1).to(torch.float64), - min=eps, - max=float('inf') + x_reshaped.abs().amax(dim=1).to(torch.float64), min=eps, max=float("inf") ) # Compute scales (num_blocks, K) -> (num_blocks, 1, K) for broadcasting @@ -129,8 +130,7 @@ def naive_fp8_blockwise_quant_transposed( # Quantize y_reshaped = x_reshaped * scale - y_reshaped = torch.clamp( - y_reshaped, min=min_fp8_e4m3, max=max_fp8_e4m3) + y_reshaped = torch.clamp(y_reshaped, min=min_fp8_e4m3, max=max_fp8_e4m3) # Reshape back to (M, K) then transpose to (K, M) y = y_reshaped.view(M, K).t().contiguous().to(torch.float8_e4m3fn) @@ -145,55 +145,71 @@ def naive_fp8_blockwise_quant_transposed( return y, s def verify_outputs( - y_naive: torch.Tensor, - s_naive: torch.Tensor, + y_torch: torch.Tensor, + s_torch: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, rtol: float = 1e-2, atol: float = 1e-2, ): - """Verify that Triton and naive implementations produce similar results.""" + """Verify that Triton and torch implementations produce similar results.""" # Convert FP8 back to float for comparison - y_naive_float = y_naive.to(torch.float32) + y_torch_float = y_torch.to(torch.float32) y_triton_float = y_triton.to(torch.float32) + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + # Check quantized values are close torch.testing.assert_close( - y_naive_float, + y_torch_float, y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations" + msg="Quantized values differ between torch and Triton implementations", ) torch.testing.assert_close( - s_naive, + s_torch, s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations" + msg="Scales differ between torch and Triton implementations", ) input_tensor = torch.randn( - M, K, + M, + K, dtype=torch.bfloat16, device=device, ) - # Benchmark naive implementation - naive_impl_c = torch.compile(naive_fp8_blockwise_quant_transposed) + # Benchmark torch implementation + torch_impl_c = torch.compile(torch_fp8_blockwise_quant_transposed) # Benchmark after warmup - y_naive, s_naive = naive_impl_c(input_tensor, block_size) - naive_time_us = benchmark_cuda_function_in_microseconds( - naive_impl_c, + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, input_tensor, block_size, ) y_triton, s_triton = triton_fp8_blockwise_act_quant_transposed_lhs( - input_tensor, block_size) + input_tensor, block_size + ) triton_time_us = benchmark_cuda_function_in_microseconds( triton_fp8_blockwise_act_quant_transposed_lhs, input_tensor, @@ -201,8 +217,7 @@ def verify_outputs( ) # Verify correctness - verify_outputs(y_naive, s_naive, y_triton, - s_triton) + verify_outputs(y_torch, s_torch, y_triton, s_triton) # Memory bandwidth calculations 
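    # Effective bandwidth is modeled as total bytes moved divided by kernel time:
    # gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6), with byte
    # counts derived below from element counts and dtype widths via torch.finfo.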
bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 @@ -211,17 +226,16 @@ def verify_outputs( read_bytes = input_tensor.numel() * bytes_per_input_el write_bytes = ( - y_triton.numel() * bytes_per_output_el - + s_triton.numel() * bytes_per_scale_el + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el ) - naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) return ExperimentResult( - naive_us=naive_time_us, + torch_us=torch_time_us, triton_us=triton_time_us, - naive_gbps=naive_gbps, + torch_gbps=torch_gbps, triton_gbps=triton_gbps, ) @@ -230,23 +244,23 @@ def print_results(experiments: List[Experiment]): headers = [ "input_shape (M, K)", "block_size", - "naive_us", + "torch_us", "triton_us", "speedup", - "naive_gbps", + "torch_gbps", "triton_gbps", ] rows = [] for experiment in experiments: - speedup = experiment.result.naive_us / experiment.result.triton_us + speedup = experiment.result.torch_us / experiment.result.triton_us rows.append( [ f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", experiment.config.block_size, - f"{experiment.result.naive_us:.2f}", + f"{experiment.result.torch_us:.2f}", f"{experiment.result.triton_us:.2f}", f"{speedup:.2f}x", - f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.torch_gbps:.1f}", f"{experiment.result.triton_gbps:.1f}", ] ) @@ -264,9 +278,9 @@ def main(): result = run_experiment(config) results.append(Experiment(config=config, result=result)) - print("\n" + "="*80) + print("\n" + "=" * 80) print("BENCHMARK RESULTS - Transposed LHS Quantization") - print("="*80 + "\n") + print("=" * 80 + "\n") print_results(results) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py index 557402570d..5c28a5c04f 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py @@ -32,10 +32,10 @@ class ExperimentConfig: @dataclass(frozen=True) class ExperimentResult: # time - naive_us: float + torch_us: float triton_us: float # mem bw - naive_gbps: float + torch_gbps: float triton_gbps: float @@ -60,6 +60,10 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), ] configs = [] @@ -80,16 +84,16 @@ def get_configs() -> List[ExperimentConfig]: def run_experiment(config: ExperimentConfig) -> ExperimentResult: """ - Run benchmark experiment comparing naive and Triton implementations. + Run benchmark experiment comparing torch and Triton implementations. """ M, N = config.input_shape block_size = config.block_size - def naive_fp8_blockwise_weight_quant( + def torch_fp8_blockwise_weight_quant( x: torch.Tensor, block_size: int = 128 ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Naive PyTorch reference implementation for blockwise FP8 weight quantization. + Torch reference implementation for blockwise FP8 weight quantization. Quantizes in (block_size x block_size) blocks. Each block gets one scale factor. Outputs in column-major format for RHS operator. 
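For reference, a minimal standalone sketch (assuming M and N are divisible by the tile size, and not taken verbatim from the benchmark) of how one absolute max per (block_size x block_size) tile can be computed with a blocked view and permute, mirroring the reference implementation above:

    import torch

    def tile_amax(x: torch.Tensor, tile: int = 128) -> torch.Tensor:
        # View as (M_blocks, tile, N_blocks, tile), bring the two tile dims together,
        # then reduce over them to get one abs-max per (tile x tile) block.
        M, N = x.shape
        blocks = x.view(M // tile, tile, N // tile, tile).permute(0, 2, 1, 3)
        return blocks.abs().amax(dim=(2, 3))  # shape (M_blocks, N_blocks)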
@@ -160,35 +164,49 @@ def naive_fp8_blockwise_weight_quant( return y, s def verify_outputs( - y_naive: torch.Tensor, - s_naive: torch.Tensor, + y_torch: torch.Tensor, + s_torch: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, rtol: float = 1e-2, atol: float = 1e-2, ): - """Verify that Triton and naive implementations produce similar results.""" + """Verify that Triton and torch implementations produce similar results.""" # Convert FP8 back to float for comparison - y_naive_float = y_naive.to(torch.float32) + y_torch_float = y_torch.to(torch.float32) y_triton_float = y_triton.to(torch.float32) + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + # Check quantized values are close torch.testing.assert_close( - y_naive_float, + y_torch_float, y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations", + msg="Quantized values differ between torch and Triton implementations", ) torch.testing.assert_close( - s_naive, + s_torch, s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations", + msg="Scales differ between torch and Triton implementations", ) # Create input tensor @@ -199,11 +217,11 @@ def verify_outputs( device=device, ) - # Benchmark naive implementation - naive_impl_c = torch.compile(naive_fp8_blockwise_weight_quant) - y_naive, s_naive = naive_impl_c(input_tensor, block_size) - naive_time_us = benchmark_cuda_function_in_microseconds( - naive_impl_c, + # Benchmark torch implementation + torch_impl_c = torch.compile(torch_fp8_blockwise_weight_quant) + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, input_tensor, block_size, ) @@ -217,7 +235,7 @@ def verify_outputs( ) # Verify correctness - verify_outputs(y_naive, s_naive, y_triton, s_triton) + verify_outputs(y_torch, s_torch, y_triton, s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 @@ -229,13 +247,13 @@ def verify_outputs( y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el ) - naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) return ExperimentResult( - naive_us=naive_time_us, + torch_us=torch_time_us, triton_us=triton_time_us, - naive_gbps=naive_gbps, + torch_gbps=torch_gbps, triton_gbps=triton_gbps, ) @@ -245,23 +263,23 @@ def print_results(experiments: List[Experiment]): headers = [ "input_shape (M, N)", "block_size", - "naive_us", + "torch_us", "triton_us", "speedup", - "naive_gbps", + "torch_gbps", "triton_gbps", ] rows = [] for experiment in experiments: - speedup = experiment.result.naive_us / experiment.result.triton_us + speedup = experiment.result.torch_us / experiment.result.triton_us rows.append( [ f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", experiment.config.block_size, - 
f"{experiment.result.naive_us:.2f}", + f"{experiment.result.torch_us:.2f}", f"{experiment.result.triton_us:.2f}", f"{speedup:.2f}x", - f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.torch_gbps:.1f}", f"{experiment.result.triton_gbps:.1f}", ] ) diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py index 7421e0f7d0..f0cf2cd54f 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -32,10 +32,10 @@ class ExperimentConfig: @dataclass(frozen=True) class ExperimentResult: # time - naive_us: float + torch_us: float triton_us: float # mem bw - naive_gbps: float + torch_gbps: float triton_gbps: float @@ -60,6 +60,10 @@ def get_configs() -> List[ExperimentConfig]: (2048, 4096), (4096, 4096), (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), ] configs = [] @@ -80,16 +84,16 @@ def get_configs() -> List[ExperimentConfig]: def run_experiment(config: ExperimentConfig) -> ExperimentResult: """ - Run benchmark experiment comparing naive and Triton implementations. + Run benchmark experiment comparing torch and Triton implementations. """ M, N = config.input_shape block_size = config.block_size - def naive_fp8_blockwise_weight_quant_transposed( + def torch_fp8_blockwise_weight_quant_transposed( x: torch.Tensor, block_size: int = 128 ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Naive PyTorch reference implementation for blockwise FP8 weight quantization with transpose. + Torch reference implementation for blockwise FP8 weight quantization with transpose. Quantizes in (block_size x block_size) blocks. Each block gets one scale factor. Outputs transposed tensor (N, M) in column-major format for RHS operator. 
@@ -157,35 +161,49 @@ def naive_fp8_blockwise_weight_quant_transposed( return y, s def verify_outputs( - y_naive: torch.Tensor, - s_naive: torch.Tensor, + y_torch: torch.Tensor, + s_torch: torch.Tensor, y_triton: torch.Tensor, s_triton: torch.Tensor, rtol: float = 1e-2, atol: float = 1e-2, ): - """Verify that Triton and naive implementations produce similar results.""" + """Verify that Triton and torch implementations produce similar results.""" # Convert FP8 back to float for comparison - y_naive_float = y_naive.to(torch.float32) + y_torch_float = y_torch.to(torch.float32) y_triton_float = y_triton.to(torch.float32) + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + # Check quantized values are close torch.testing.assert_close( - y_naive_float, + y_torch_float, y_triton_float, rtol=rtol, atol=atol, - msg="Quantized values differ between naive and Triton implementations", + msg="Quantized values differ between torch and Triton implementations", ) torch.testing.assert_close( - s_naive, + s_torch, s_triton, rtol=rtol, atol=atol, - msg="Scales differ between naive and Triton implementations", + msg="Scales differ between torch and Triton implementations", ) # Create input tensor @@ -196,11 +214,11 @@ def verify_outputs( device=device, ) - # Benchmark naive implementation (torch.compile handles warmup) - naive_impl_c = torch.compile(naive_fp8_blockwise_weight_quant_transposed) - y_naive, s_naive = naive_impl_c(input_tensor, block_size) - naive_time_us = benchmark_cuda_function_in_microseconds( - naive_impl_c, + # Benchmark torch implementation (torch.compile handles warmup) + torch_impl_c = torch.compile(torch_fp8_blockwise_weight_quant_transposed) + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, input_tensor, block_size, ) @@ -216,7 +234,7 @@ def verify_outputs( ) # Verify correctness - verify_outputs(y_naive, s_naive, y_triton, s_triton) + verify_outputs(y_torch, s_torch, y_triton, s_triton) # Memory bandwidth calculations bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 @@ -228,13 +246,13 @@ def verify_outputs( y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el ) - naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6) + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) return ExperimentResult( - naive_us=naive_time_us, + torch_us=torch_time_us, triton_us=triton_time_us, - naive_gbps=naive_gbps, + torch_gbps=torch_gbps, triton_gbps=triton_gbps, ) @@ -244,23 +262,23 @@ def print_results(experiments: List[Experiment]): headers = [ "input_shape (M, N)", "block_size", - "naive_us", + "torch_us", "triton_us", "speedup", - "naive_gbps", + "torch_gbps", "triton_gbps", ] rows = [] for experiment in experiments: - speedup = experiment.result.naive_us / experiment.result.triton_us + speedup = experiment.result.torch_us / experiment.result.triton_us rows.append( [ 
f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", experiment.config.block_size, - f"{experiment.result.naive_us:.2f}", + f"{experiment.result.torch_us:.2f}", f"{experiment.result.triton_us:.2f}", f"{speedup:.2f}x", - f"{experiment.result.naive_gbps:.1f}", + f"{experiment.result.torch_gbps:.1f}", f"{experiment.result.triton_gbps:.1f}", ] )