# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Tuple

import torch
from tabulate import tabulate
from tqdm import tqdm

# Assuming these imports based on the kernel location
from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: Tuple[int, int]  # (M, K)
    block_size: int


@dataclass(frozen=True)
class ExperimentResult:
    # time
    torch_us: float
    triton_us: float
    # mem bw
    torch_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    """
    Test configurations for typical transformer activation shapes.
    Format: (batch_size * seq_len, hidden_dim)
    """
    # Llama-style shapes: various batch*seq_len sizes with typical hidden dims
    input_shapes = [
        (512, 4096),
        (1024, 4096),
        (2048, 4096),
        (4096, 4096),
        (8192, 4096),

Contributor
Can we make the leading total_M dims (seq_len * local_batch_size) bigger? e.g. a range of 8192, 8192*2, 8192*4, 8192*8, 8192*16? This is more representative of what we'll see in real training runs.

Same for the act_quant_rhs benchmarks.

Author
Thanks. Is there any downside to having all of the above quantization benchmarks use these bigger values?

        (16384, 4096),
        (32768, 4096),
        (65536, 4096),
        (131_072, 4096),
    ]

    configs = []
    block_sizes = [128]  # Standard block size for FP8

    for shape in input_shapes:
        for block_size in block_sizes:
            configs.append(
                ExperimentConfig(
                    input_shape=shape,
                    block_size=block_size,
                )
            )
    return configs
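
The reviewer's suggested range corresponds to total_M = 8192 * {1, 2, 4, 8, 16}, which the list in get_configs above already covers (alongside the smaller shapes). Purely as an illustration of that suggestion, a hypothetical helper (the name and the hidden_dim default are assumptions, reusing the List/Tuple imports at the top of the file) could generate just that range:

def reviewer_suggested_shapes(hidden_dim: int = 4096) -> List[Tuple[int, int]]:
    # Leading total_M dims of 8192 * 2**i for i in 0..4, per the review comment above.
    return [(8192 * 2**i, hidden_dim) for i in range(5)]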


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    M, K = config.input_shape
    block_size = config.block_size

Contributor
In these various verify_outputs, can we also validate that the memory layouts are the same? i.e., check that shapes and strides match.

This way we are 100% sure we are doing a 1:1 comparison (writing to different memory layouts can drastically affect performance).

    def verify_outputs(
        y_torch: torch.Tensor,
        s_torch: torch.Tensor,
        y_triton: torch.Tensor,
        s_triton: torch.Tensor,
        rtol: float = 1e-2,
        atol: float = 1e-2,
    ):
        """Verify that Triton and torch implementations produce similar results."""

        # Convert FP8 back to float for comparison
        y_torch_float = y_torch.to(torch.float32)
        y_triton_float = y_triton.to(torch.float32)

        assert y_torch.shape == y_triton.shape, (
            f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}"
        )
        assert y_torch.stride() == y_triton.stride(), (
            f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}"
        )

        assert s_torch.shape == s_triton.shape, (
            f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}"
        )
        assert s_torch.stride() == s_triton.stride(), (
            f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}"
        )

        # Check quantized values are close
        torch.testing.assert_close(
            y_torch_float,
            y_triton_float,
            rtol=rtol,
            atol=atol,
            msg="Quantized values differ between torch and Triton implementations",
        )

        torch.testing.assert_close(
            s_torch,
            s_triton,
            rtol=rtol,
            atol=atol,
            msg="Scales differ between torch and Triton implementations",
        )

    input_tensor = torch.randn(
        M,
        K,
        dtype=torch.bfloat16,
        device=device,
    )

    # Benchmark torch implementation
    torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
    y_torch, s_torch = torch_impl_c(input_tensor, block_size)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        torch_impl_c,
        input_tensor,
        block_size,
    )

    # Benchmark Triton implementation
    y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs(input_tensor, block_size)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_fp8_blockwise_act_quant_lhs,
        input_tensor,
        block_size,
    )

    # Verify correctness (optional, can comment out for pure benchmarking)
    verify_outputs(y_torch, s_torch, y_triton, s_triton)

    # Memory bandwidth calculations
    bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8
    bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8
    bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
    )

    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_us=torch_time_us,
        triton_us=triton_time_us,
        torch_gbps=torch_gbps,
        triton_gbps=triton_gbps,
    )
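
As a side note on the stride assertions requested in the review comment above, a minimal standalone sketch (hypothetical tensors, not part of the benchmark) shows how two tensors can share a shape while having different memory layouts, which is exactly what verify_outputs guards against:

a = torch.empty(4, 8)      # contiguous, row-major: stride (8, 1)
b = torch.empty(8, 4).t()  # transposed view: same shape (4, 8), stride (1, 8)
assert a.shape == b.shape
assert a.stride() != b.stride()  # same shape, different memory layout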


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape (M, K)",
        "block_size",
        "torch_us",
        "triton_us",
        "speedup",
        "torch_gbps",
        "triton_gbps",
    ]
    rows = []
    for experiment in experiments:
        speedup = experiment.result.torch_us / experiment.result.triton_us
        rows.append(
            [
                f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                experiment.config.block_size,
                f"{experiment.result.torch_us:.2f}",
                f"{experiment.result.triton_us:.2f}",
                f"{speedup:.2f}x",
                f"{experiment.result.torch_gbps:.1f}",
                f"{experiment.result.triton_gbps:.1f}",
            ]
        )
    print(tabulate(rows, headers=headers, tablefmt="grid"))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []

    print(f"Running {len(configs)} benchmark configurations...\n")

    for config in tqdm(configs, desc="Benchmarking"):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80 + "\n")
    print_results(results)


if __name__ == "__main__":
    main()