diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py new file mode 100644 index 0000000000..ed80a358d1 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + torch_blockwise_scale_act_quant_lhs, + triton_fp8_blockwise_act_quant_lhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, K) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + torch_us: float + triton_us: float + # mem bw + torch_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical transformer activation shapes. + Format: (batch_size * seq_len, hidden_dim) + """ + # Llama-style shapes: various batch*seq_len sizes with typical hidden dims + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, K = config.input_shape + block_size = config.block_size + + def verify_outputs( + y_torch: torch.Tensor, + s_torch: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and torch implementations produce similar results.""" + + # Convert FP8 back to float for comparison + y_torch_float = y_torch.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + + # Check quantized values are close + + torch.testing.assert_close( + y_torch_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between torch and Triton implementations", + ) + + torch.testing.assert_close( + s_torch, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between torch and Triton 
implementations", + ) + + input_tensor = torch.randn( + M, + K, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark torch implementation + torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs) + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation + y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_blockwise_act_quant_lhs, + input_tensor, + block_size, + ) + + # Verify correctness (optional, can comment out for pure benchmarking) + verify_outputs(y_torch, s_torch, y_triton, s_triton) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_us=torch_time_us, + triton_us=triton_time_us, + torch_gbps=torch_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape (M, K)", + "block_size", + "torch_us", + "triton_us", + "speedup", + "torch_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.torch_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.torch_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.torch_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py new file mode 100644 index 0000000000..72f20e28fa --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py @@ -0,0 +1,274 @@ +# rhs_benchmark.py +# Copyright (c) Meta Platforms... 
+from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_rhs, # <- RHS kernel +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, K) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + torch_us: float + triton_us: float + # mem bw + torch_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical transformer activation shapes. + Format: (batch_size * seq_len, hidden_dim) + """ + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, K = config.input_shape + block_size = config.block_size + + def torch_fp8_blockwise_quant( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Torch reference implementation for RHS blockwise FP8 quantization. + + RHS semantics: + • Groups are (block_size x 1) along the M dimension (rows). + • y is returned in column-major layout (M, K). + • s has shape (ceil(M/block_size), K) in row-major (reciprocal scales). + """ + assert x.is_contiguous(), "Input must be contiguous" + + M, K = x.size() + M_blocks = (M + block_size - 1) // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Pad rows so we can reshape without a loop; then crop back. 
+ pad_rows = M_blocks * block_size - M + if pad_rows: + x = torch.nn.functional.pad(x, (0, 0, 0, pad_rows)) # pad rows at bottom + + # Reshape to (M_blocks, block_size, K) for block-wise operations along M + x_reshaped = x.view(M_blocks, block_size, K) + + # Compute max abs per column within each block -> (M_blocks, K) + amax = torch.clamp( + x_reshaped.abs().amax(dim=1).to(torch.float64), + min=eps, + max=float("inf"), + ) + + # Compute scales -> (M_blocks, 1, K) for broadcasting across rows in block + scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) + + # Quantize (still (M_blocks, block_size, K)) + y_reshaped = torch.clamp(x_reshaped * scale, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Back to (M_padded, K), then crop to (M, K) + y_rowmajor = y_reshaped.view(M_blocks * block_size, K)[:M, :].to( + torch.float8_e4m3fn + ) + + # y must be column-major per RHS kernel + y = y_rowmajor.t().contiguous().t() + + # Reciprocal scales (row-major) -> (M_blocks, K) + reciprocal_scale = 1.0 / scale.squeeze(1) + s = reciprocal_scale + + return y, s + + def verify_outputs( + y_torch: torch.Tensor, + s_torch: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and torch implementations produce similar results.""" + + # Quantized tensors (both are column-major; convert to float to compare) + y_torch_float = y_torch.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + + torch.testing.assert_close( + y_torch_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between torch and Triton implementations", + ) + + torch.testing.assert_close( + s_torch, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between torch and Triton implementations", + ) + + input_tensor = torch.randn( + M, + K, + dtype=torch.bfloat16, + device=device, + ) + + # Compile once + torch_impl_c = torch.compile(torch_fp8_blockwise_quant) + + # Benchmark torch implementation + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation + y_triton, s_triton = triton_fp8_blockwise_act_quant_rhs(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_blockwise_act_quant_rhs, + input_tensor, + block_size, + ) + + # Verify correctness (compare to torch) + verify_outputs(y_torch, s_torch, y_triton, s_triton) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_gbps = ((read_bytes + 
write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_us=torch_time_us, + triton_us=triton_time_us, + torch_gbps=torch_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape (M, K)", + "block_size", + "torch_us", + "triton_us", + "speedup", + "torch_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.torch_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.torch_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.torch_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking RHS"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS (RHS)") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs.py new file mode 100644 index 0000000000..c5496a97db --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_transposed_lhs.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_transposed_lhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, K) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + torch_us: float + triton_us: float + # mem bw + torch_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical transformer activation shapes. 
+ Format: (batch_size * seq_len, hidden_dim) + + Note: For transposed_lhs, M must be divisible by block_size + """ + # Llama-style shapes: various batch*seq_len sizes with typical hidden dims + # Ensuring M is divisible by block_size (128) + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + # Verify M is divisible by block_size + if shape[0] % block_size == 0: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, K = config.input_shape + block_size = config.block_size + + def torch_fp8_blockwise_quant_transposed( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Torch reference implementation for blockwise FP8 quantization with transpose. + + This version: + 1. Computes column-wise scales (along dimension 0) + 2. Outputs transposed quantized tensor (K, M) in row-major format + 3. Outputs scales in shape (K, M//block_size) + + Args: + x: Input tensor of shape (M, K) + block_size: Number of elements per block along M dimension + + Returns: + y: Transposed quantized tensor in FP8, shape (K, M) in row-major + s: Reciprocal scales in column-major format (K, M//block_size) + """ + assert x.is_contiguous(), "Input must be contiguous" + assert x.size(0) % block_size == 0, "M must be divisible by block_size" + + M, K = x.size() + num_blocks = M // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Reshape to (num_blocks, block_size, K) for block-wise operations along M + x_reshaped = x.view(num_blocks, block_size, K) + + # Compute max absolute value per block along dimension 1 (block_size) + # Result shape: (num_blocks, K) + amax = torch.clamp( + x_reshaped.abs().amax(dim=1).to(torch.float64), min=eps, max=float("inf") + ) + + # Compute scales (num_blocks, K) -> (num_blocks, 1, K) for broadcasting + scale = (max_fp8_e4m3 / amax).to(torch.float32).unsqueeze(1) + + # Quantize + y_reshaped = x_reshaped * scale + y_reshaped = torch.clamp(y_reshaped, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Reshape back to (M, K) then transpose to (K, M) + y = y_reshaped.view(M, K).t().contiguous().to(torch.float8_e4m3fn) + + # Compute reciprocal scales - explicitly cast to float32 + reciprocal_scale = 1.0 / scale.squeeze(1) + # reciprocal_scale is (num_blocks, K), need to transpose to (K, num_blocks) + reciprocal_scale = reciprocal_scale.t() + + s = reciprocal_scale.t().contiguous().t() + + return y, s + + def verify_outputs( + y_torch: torch.Tensor, + s_torch: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and torch implementations produce similar results.""" + + # Convert FP8 back to float for comparison + y_torch_float = y_torch.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton 
{s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + + # Check quantized values are close + torch.testing.assert_close( + y_torch_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between torch and Triton implementations", + ) + + torch.testing.assert_close( + s_torch, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between torch and Triton implementations", + ) + + input_tensor = torch.randn( + M, + K, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark torch implementation + torch_impl_c = torch.compile(torch_fp8_blockwise_quant_transposed) + + # Benchmark after warmup + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, + input_tensor, + block_size, + ) + + y_triton, s_triton = triton_fp8_blockwise_act_quant_transposed_lhs( + input_tensor, block_size + ) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_blockwise_act_quant_transposed_lhs, + input_tensor, + block_size, + ) + + # Verify correctness + verify_outputs(y_torch, s_torch, y_triton, s_triton) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_us=torch_time_us, + triton_us=triton_time_us, + torch_gbps=torch_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape (M, K)", + "block_size", + "torch_us", + "triton_us", + "speedup", + "torch_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.torch_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.torch_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.torch_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS - Transposed LHS Quantization") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py new file mode 100644 index 0000000000..5c28a5c04f --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_rhs.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_weight_quant_rhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, N) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + torch_us: float + triton_us: float + # mem bw + torch_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical weight matrix shapes. + Format: (hidden_dim, hidden_dim) for square matrices or (hidden_dim_in, hidden_dim_out) + + Note: Both M and N must be divisible by block_size (128) + """ + # Common weight matrix shapes in transformers + # Format: (in_features, out_features) for weight matrices + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + # Verify both dimensions are divisible by block_size + if shape[0] % block_size == 0 and shape[1] % block_size == 0: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + """ + Run benchmark experiment comparing torch and Triton implementations. + """ + M, N = config.input_shape + block_size = config.block_size + + def torch_fp8_blockwise_weight_quant( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Torch reference implementation for blockwise FP8 weight quantization. + + Quantizes in (block_size x block_size) blocks. Each block gets one scale factor. + Outputs in column-major format for RHS operator. 
+ + Args: + x: Input tensor of shape (M, N), row-major + block_size: Block size for quantization + + Returns: + y: Quantized tensor in FP8, shape (M, N), column-major format + s: Reciprocal scales in column-major format (M//block_size, N//block_size) + """ + assert x.is_contiguous(), "Input must be contiguous" + assert x.dim() == 2, "Input must be 2D" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + "Both dimensions must be divisible by block_size" + ) + + M, N = x.size() + M_blocks = M // block_size + N_blocks = N // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Reshape to (M_blocks, block_size, N_blocks, block_size) for block-wise operations + x_reshaped = x.view(M_blocks, block_size, N_blocks, block_size) + # Permute to (M_blocks, N_blocks, block_size, block_size) for easier block processing + x_blocks = x_reshaped.permute(0, 2, 1, 3) + + # Compute max absolute value per block (M_blocks, N_blocks) + amax = torch.clamp( + x_blocks.reshape(M_blocks, N_blocks, -1) + .abs() + .amax(dim=2) + .to(torch.float64), + min=eps, + max=float("inf"), + ) + + # Compute scales (M_blocks, N_blocks) + scale = (max_fp8_e4m3 / amax).to(torch.float32) + + # Broadcast scale for quantization (M_blocks, N_blocks, 1, 1) + scale_broadcast = scale.unsqueeze(2).unsqueeze(3) + + # Quantize + y_blocks = x_blocks * scale_broadcast + y_blocks = torch.clamp(y_blocks, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Reshape back and convert to FP8 + # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) + y_reshaped = y_blocks.permute(0, 2, 1, 3) + y_rowmajor = y_reshaped.reshape(M, N).to(torch.float8_e4m3fn) + + # Convert to column-major format + + y = y_rowmajor.t().contiguous().t() + + # Compute reciprocal scales - explicitly cast to float32 + reciprocal_scale = 1.0 / scale + + # Convert to column-major + s = reciprocal_scale.t().contiguous().t() + + return y, s + + def verify_outputs( + y_torch: torch.Tensor, + s_torch: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and torch implementations produce similar results.""" + + # Convert FP8 back to float for comparison + + y_torch_float = y_torch.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + + # Check quantized values are close + torch.testing.assert_close( + y_torch_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between torch and Triton implementations", + ) + + torch.testing.assert_close( + s_torch, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between torch and Triton implementations", + ) + + # Create input tensor + input_tensor = torch.randn( + M, + N, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark torch implementation + torch_impl_c = torch.compile(torch_fp8_blockwise_weight_quant) + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + 
torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation + y_triton, s_triton = triton_fp8_blockwise_weight_quant_rhs(input_tensor, block_size) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_blockwise_weight_quant_rhs, + input_tensor, + block_size, + ) + + # Verify correctness + verify_outputs(y_torch, s_torch, y_triton, s_triton) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_us=torch_time_us, + triton_us=triton_time_us, + torch_gbps=torch_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + """Print benchmark results in a formatted table.""" + headers = [ + "input_shape (M, N)", + "block_size", + "torch_us", + "triton_us", + "speedup", + "torch_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.torch_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.torch_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.torch_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + """Main benchmark execution.""" + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS - RHS Weight Quantization") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py new file mode 100644 index 0000000000..f0cf2cd54f --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_weight_quant_transposed_rhs.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from tabulate import tabulate +from tqdm import tqdm + +# Assuming these imports based on the kernel location +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_weight_quant_transposed_rhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: Tuple[int, int] # (M, N) + block_size: int + + +@dataclass(frozen=True) +class ExperimentResult: + # time + torch_us: float + triton_us: float + # mem bw + torch_gbps: float + triton_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + """ + Test configurations for typical weight matrix shapes. + Format: (hidden_dim, hidden_dim) for square matrices or (hidden_dim_in, hidden_dim_out) + + Note: Both M and N must be divisible by block_size (128) + """ + # Common weight matrix shapes in transformers + # Format: (in_features, out_features) for weight matrices + input_shapes = [ + (512, 4096), + (1024, 4096), + (2048, 4096), + (4096, 4096), + (8192, 4096), + (16384, 4096), + (32768, 4096), + (65536, 4096), + (131_072, 4096), + ] + + configs = [] + block_sizes = [128] # Standard block size for FP8 + + for shape in input_shapes: + for block_size in block_sizes: + # Verify both dimensions are divisible by block_size + if shape[0] % block_size == 0 and shape[1] % block_size == 0: + configs.append( + ExperimentConfig( + input_shape=shape, + block_size=block_size, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + """ + Run benchmark experiment comparing torch and Triton implementations. + """ + M, N = config.input_shape + block_size = config.block_size + + def torch_fp8_blockwise_weight_quant_transposed( + x: torch.Tensor, block_size: int = 128 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Torch reference implementation for blockwise FP8 weight quantization with transpose. + + Quantizes in (block_size x block_size) blocks. Each block gets one scale factor. + Outputs transposed tensor (N, M) in column-major format for RHS operator. 
+ + Args: + x: Input tensor of shape (M, N), row-major + block_size: Block size for quantization + + Returns: + y: Transposed quantized tensor in FP8, shape (N, M), column-major format + s: Reciprocal scales in column-major format (N//block_size, M//block_size) + """ + assert x.is_contiguous(), "Input must be contiguous" + assert x.dim() == 2, "Input must be 2D" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + "Both dimensions must be divisible by block_size" + ) + + M, N = x.size() + M_blocks = M // block_size + N_blocks = N // block_size + + # FP8 E4M3 constants + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + eps = 1e-12 + + # Reshape to (M_blocks, block_size, N_blocks, block_size) + x_reshaped = x.view(M_blocks, block_size, N_blocks, block_size) + # Permute to (M_blocks, N_blocks, block_size, block_size) for easier block processing + x_blocks = x_reshaped.permute(0, 2, 1, 3) + + # Compute max absolute value per block (M_blocks, N_blocks) + amax = torch.clamp( + x_blocks.abs().amax(dim=(2, 3)).to(torch.float64), min=eps, max=float("inf") + ) + + # Compute scales (M_blocks, N_blocks) + scale = (max_fp8_e4m3 / amax).to(torch.float32) + + # Broadcast scale for quantization (M_blocks, N_blocks, 1, 1) + scale_broadcast = scale[:, :, None, None] + + # Quantize + y_blocks = x_blocks * scale_broadcast + y_blocks = torch.clamp(y_blocks, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Permute back: (M_blocks, N_blocks, block_size, block_size) -> (M_blocks, block_size, N_blocks, block_size) + y_reshaped = y_blocks.permute(0, 2, 1, 3) + # Reshape to (M, N) then transpose to (N, M) + y_rowmajor = y_reshaped.reshape(M, N).t() + + # Convert to FP8 and create column-major output (matching Triton kernel) + y = y_rowmajor.t().contiguous().t() + y = y_rowmajor.to(torch.float8_e4m3fn) + + # Compute reciprocal scales + reciprocal_scale = 1.0 / scale + # Transpose scale matrix to match output dimensions: (M_blocks, N_blocks) -> (N_blocks, M_blocks) + reciprocal_scale = reciprocal_scale.t() + + # Convert to col-major + s = reciprocal_scale.t().contiguous().t() + + return y, s + + def verify_outputs( + y_torch: torch.Tensor, + s_torch: torch.Tensor, + y_triton: torch.Tensor, + s_triton: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + ): + """Verify that Triton and torch implementations produce similar results.""" + + # Convert FP8 back to float for comparison + + y_torch_float = y_torch.to(torch.float32) + y_triton_float = y_triton.to(torch.float32) + + assert y_torch.shape == y_triton.shape, ( + f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}" + ) + assert y_torch.stride() == y_triton.stride(), ( + f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}" + ) + + assert s_torch.shape == s_triton.shape, ( + f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}" + ) + assert s_torch.stride() == s_triton.stride(), ( + f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}" + ) + + # Check quantized values are close + torch.testing.assert_close( + y_torch_float, + y_triton_float, + rtol=rtol, + atol=atol, + msg="Quantized values differ between torch and Triton implementations", + ) + + torch.testing.assert_close( + s_torch, + s_triton, + rtol=rtol, + atol=atol, + msg="Scales differ between torch and Triton implementations", + ) + + # Create input tensor + input_tensor = torch.randn( + M, + N, + dtype=torch.bfloat16, + device=device, + ) + + # Benchmark torch implementation (torch.compile 
handles warmup) + torch_impl_c = torch.compile(torch_fp8_blockwise_weight_quant_transposed) + y_torch, s_torch = torch_impl_c(input_tensor, block_size) + torch_time_us = benchmark_cuda_function_in_microseconds( + torch_impl_c, + input_tensor, + block_size, + ) + + # Benchmark Triton implementation (torch.compile handles warmup) + y_triton, s_triton = triton_fp8_blockwise_weight_quant_transposed_rhs( + input_tensor, block_size + ) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_blockwise_weight_quant_transposed_rhs, + input_tensor, + block_size, + ) + + # Verify correctness + verify_outputs(y_torch, s_torch, y_triton, s_triton) + + # Memory bandwidth calculations + bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8 + bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8 + bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el + ) + + torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_us=torch_time_us, + triton_us=triton_time_us, + torch_gbps=torch_gbps, + triton_gbps=triton_gbps, + ) + + +def print_results(experiments: List[Experiment]): + """Print benchmark results in a formatted table.""" + headers = [ + "input_shape (M, N)", + "block_size", + "torch_us", + "triton_us", + "speedup", + "torch_gbps", + "triton_gbps", + ] + rows = [] + for experiment in experiments: + speedup = experiment.result.torch_us / experiment.result.triton_us + rows.append( + [ + f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}", + experiment.config.block_size, + f"{experiment.result.torch_us:.2f}", + f"{experiment.result.triton_us:.2f}", + f"{speedup:.2f}x", + f"{experiment.result.torch_gbps:.1f}", + f"{experiment.result.triton_gbps:.1f}", + ] + ) + print(tabulate(rows, headers=headers, tablefmt="grid")) + + +def main(): + """Main benchmark execution.""" + torch.random.manual_seed(123) + configs = get_configs() + results = [] + + print(f"Running {len(configs)} benchmark configurations...\n") + + for config in tqdm(configs, desc="Benchmarking"): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + print("\n" + "=" * 80) + print("BENCHMARK RESULTS - Transposed RHS Weight Quantization") + print("=" * 80 + "\n") + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py index 3f82407d40..2ceb839173 100644 --- a/torchao/prototype/blockwise_fp8_training/kernels.py +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -116,10 +116,13 @@ def triton_fp8_gemm_1x128_128x128( K = a.size(1) N = b.size(1) c = a.new_empty(M, N, dtype=out_dtype) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]), - triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid]( a, a.stride(0), @@ -231,10 +234,13 @@ def triton_fp8_gemm_1x128_128x1( K = a.size(1) N = b.size(1) c = a.new_empty(M, N, dtype=out_dtype) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]), - triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + + def grid(META): + return ( + 
triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid]( a, a.stride(0), @@ -350,10 +356,13 @@ def triton_fp8_blockwise_act_quant_lhs( (M, K // block_size), (1, M), ) - grid = lambda meta: ( - triton.cdiv(M, meta["NUM_GROUPS"]), - triton.cdiv(K, meta["BLOCK_SIZE"]), - ) + + def grid(meta): + return ( + triton.cdiv(M, meta["NUM_GROUPS"]), + triton.cdiv(K, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid]( x, x.stride(0), @@ -443,10 +452,12 @@ def triton_fp8_blockwise_act_quant_rhs( y = y.as_strided(y.size(), (1, y.size(0))) s = x.new_empty(M_blocks, K, dtype=torch.float32) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(K, meta["NUM_GROUPS"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(K, meta["NUM_GROUPS"]), + ) + wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid]( x, x.stride(0), @@ -549,10 +560,12 @@ def triton_fp8_blockwise_act_quant_transposed_lhs( (K, M_blocks), # shape (1, K), # stride ) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(K, meta["NUM_GROUPS"]), - ) + + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(K, meta["NUM_GROUPS"]), + ) wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid]( x, @@ -639,10 +652,13 @@ def triton_fp8_blockwise_weight_quant_rhs( (M_blocks, N_blocks), # shape (1, M_blocks), # stride ) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) + + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid]( x, x.stride(0), @@ -744,10 +760,13 @@ def triton_fp8_blockwise_weight_quant_transposed_rhs( (n_blocks, m_blocks), # shape (1, n_blocks), # stride ) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) + + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid]( x, x.stride(0), @@ -794,6 +813,7 @@ def torch_blockwise_scale_act_quant_lhs(x, tile_size=128): # Reshape quantized output back to original shape and reshape scales accordingly x = x.reshape(*orig_shape) s = s.reshape(orig_shape[0], -1).to(torch.float) + s = s.transpose(-2, -1).contiguous().transpose(-2, -1) # Return output tensor and reciprocal scale return x, 1.0 / s
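
As context for the reference implementations above: each one computes a per-group amax, scales by 448.0 / amax (the FP8 E4M3 maximum), clamps, casts to float8_e4m3fn, and returns reciprocal scales so that dequantization is y * s. Below is a minimal standalone sketch of the (1 x block_size) groupwise variant; the function name is hypothetical and this is not the prototype kernels' API, just an illustration of the math the benchmarks verify against the Triton kernels.

import torch

def groupwise_fp8_quant_sketch(x: torch.Tensor, block_size: int = 128):
    # Quantize each (1 x block_size) group along the last dim to float8_e4m3fn.
    # Returns the quantized tensor and reciprocal scales (dequant: y.float() * s).
    assert x.shape[-1] % block_size == 0, "last dim must be divisible by block_size"
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3
    groups = x.reshape(-1, block_size).float()
    amax = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax
    y = (groups * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return y.reshape(x.shape), (1.0 / scale).reshape(*x.shape[:-1], -1)

# Example: x of shape (M, K) with K % 128 == 0
# y, s = groupwise_fp8_quant_sketch(torch.randn(256, 4096, dtype=torch.bfloat16))
# x_approx = y.float().reshape(-1, 128) * s.reshape(-1, 1)  # blockwise dequant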