Re: #3290 FP8 Blockwise Training Tracker, quantization benchmarks #3306
Open: agolajko wants to merge 19 commits into pytorch:main from agolajko:feat/fp8-quant-bench-3290
+1,449 −28

Commits (19, all by agolajko):
577a570  benchmark for triton_fp8_blockwise_act_quant_transposed_lhs against n…
e5c8601  benchmark for triton_fp8_blockwise_act_quant_lhs against naive implem…
ac3b550  benchmark for triton_fp8_blockwise_act_quant_rhs against naive implem…
ee3a26e  bench for triton_fp8_blockwise_weight_quant_rhs against naive torch i…
0bc1597  bench for triton_fp8_blockwise_weight_quant_transposed_rhs against na…
a36bb48  removed extra space from file name
0b8b05e  Flipped mem layout of scales to streamline the LHS activation quantiz…
066b346  updates the LHS act
4ad066a  output bytes calc corrected
e464ad5  minor changes to RHS activation bench
36f34ca  changes to testing
278cb70  forgot to lint
873ba81  Act LHS minor changes re PR comments
8281e7b  Act RHS changes re PR comments
2175611  Weight transposed RHS fixes for pr comments
8525822  Activation RHS fixes for PR comments
c5b058c  LHS activation transposed fix for PR comments
83af1d7  lint
a22d36f  Fixes to comments
benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py (224 additions, 0 deletions)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Tuple

import torch
from tabulate import tabulate
from tqdm import tqdm

# Assuming these imports based on the kernel location
from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: Tuple[int, int]  # (M, K)
    block_size: int


@dataclass(frozen=True)
class ExperimentResult:
    # time
    torch_us: float
    triton_us: float
    # mem bw
    torch_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    """
    Test configurations for typical transformer activation shapes.
    Format: (batch_size * seq_len, hidden_dim)
    """
    # Llama-style shapes: various batch*seq_len sizes with typical hidden dims
    input_shapes = [
        (512, 4096),
        (1024, 4096),
        (2048, 4096),
        (4096, 4096),
        (8192, 4096),
        (16384, 4096),
        (32768, 4096),
        (65536, 4096),
        (131_072, 4096),
    ]

    configs = []
    block_sizes = [128]  # Standard block size for FP8

    for shape in input_shapes:
        for block_size in block_sizes:
            configs.append(
                ExperimentConfig(
                    input_shape=shape,
                    block_size=block_size,
                )
            )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    M, K = config.input_shape
    block_size = config.block_size

    def verify_outputs(
        y_torch: torch.Tensor,
        s_torch: torch.Tensor,
        y_triton: torch.Tensor,
        s_triton: torch.Tensor,
        rtol: float = 1e-2,
        atol: float = 1e-2,
    ):
        """Verify that Triton and torch implementations produce similar results."""

        # Convert FP8 back to float for comparison
        y_torch_float = y_torch.to(torch.float32)
        y_triton_float = y_triton.to(torch.float32)

        assert y_torch.shape == y_triton.shape, (
            f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}"
        )
        assert y_torch.stride() == y_triton.stride(), (
            f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}"
        )

        assert s_torch.shape == s_triton.shape, (
            f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}"
        )
        assert s_torch.stride() == s_triton.stride(), (
            f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}"
        )

        # Check quantized values are close
        torch.testing.assert_close(
            y_torch_float,
            y_triton_float,
            rtol=rtol,
            atol=atol,
            msg="Quantized values differ between torch and Triton implementations",
        )

        torch.testing.assert_close(
            s_torch,
            s_triton,
            rtol=rtol,
            atol=atol,
            msg="Scales differ between torch and Triton implementations",
        )

    input_tensor = torch.randn(
        M,
        K,
        dtype=torch.bfloat16,
        device=device,
    )

    # Benchmark torch implementation
    torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
    y_torch, s_torch = torch_impl_c(input_tensor, block_size)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        torch_impl_c,
        input_tensor,
        block_size,
    )

    # Benchmark Triton implementation
    y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs(input_tensor, block_size)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_fp8_blockwise_act_quant_lhs,
        input_tensor,
        block_size,
    )

    # Verify correctness (optional, can comment out for pure benchmarking)
    verify_outputs(y_torch, s_torch, y_triton, s_triton)

    # Memory bandwidth calculations
    bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8
    bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8
    bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
    )

    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

    return ExperimentResult(
        torch_us=torch_time_us,
        triton_us=triton_time_us,
        torch_gbps=torch_gbps,
        triton_gbps=triton_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape (M, K)",
        "block_size",
        "torch_us",
        "triton_us",
        "speedup",
        "torch_gbps",
        "triton_gbps",
    ]
    rows = []
    for experiment in experiments:
        speedup = experiment.result.torch_us / experiment.result.triton_us
        rows.append(
            [
                f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                experiment.config.block_size,
                f"{experiment.result.torch_us:.2f}",
                f"{experiment.result.triton_us:.2f}",
                f"{speedup:.2f}x",
                f"{experiment.result.torch_gbps:.1f}",
                f"{experiment.result.triton_gbps:.1f}",
            ]
        )
    print(tabulate(rows, headers=headers, tablefmt="grid"))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []

    print(f"Running {len(configs)} benchmark configurations...\n")

    for config in tqdm(configs, desc="Benchmarking"):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80 + "\n")
    print_results(results)


if __name__ == "__main__":
    main()
```

Contributor review comment on verify_outputs: in these various benchmarks, this way we are 100% sure we are doing a 1:1 comparison (writing to different memory layouts can drastically affect performance).
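As a rough sanity check on the bandwidth figures the script reports, the sketch below reruns the same arithmetic as run_experiment for a single shape. The assumptions here are not taken from the PR: one fp32 scale per 1x128 block along K (so M * K / block_size scale elements) and a hypothetical 100 us kernel time; the real scale layout and timings come from the kernels themselves.

```python
# Back-of-the-envelope version of the byte counts computed in run_experiment,
# for the (8192, 4096) shape with block_size=128.
# Assumes bf16 input (2 B/elem), fp8 output (1 B/elem), fp32 scales (4 B/elem),
# and one scale per 1x128 block along K.
M, K, block_size = 8192, 4096, 128
read_bytes = M * K * 2                               # bf16 input
write_bytes = M * K * 1 + (M * K // block_size) * 4  # fp8 output + fp32 scales
total_gb = (read_bytes + write_bytes) / 1e9          # ~0.102 GB moved
# With a hypothetical 100 us kernel time, achieved bandwidth would be ~1017 GB/s:
gbps = total_gb / (100 / 1e6)
print(f"{total_gb:.3f} GB moved -> {gbps:.0f} GB/s at 100 us")
```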
Reviewer comment: Can we make the leading total_M dims (seq_len * local_batch_size) bigger? E.g. a range of 8192, 8192*2, 8192*4, 8192*8, 8192*16? This is more representative of what we'll see in real training runs. Same for the act_quant_rhs benchmarks.
Reply: Thanks. Is there any downside to running all of the above quantization benchmarks with these bigger values?
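For illustration only, a minimal sketch of what the suggested config change could look like, reusing ExperimentConfig and List from the benchmark file above. Keeping the hidden dim at 4096 and the block size at 128 is an assumption carried over from the existing configs, not something specified in the review.

```python
# Hypothetical get_configs() variant reflecting the reviewer's suggestion:
# larger leading total_M dims (seq_len * local_batch_size), which are more
# representative of real training runs. Hidden dim (4096) and block size (128)
# are assumptions carried over from the benchmark above.
def get_configs() -> List[ExperimentConfig]:
    total_m_values = [8192, 8192 * 2, 8192 * 4, 8192 * 8, 8192 * 16]
    return [
        ExperimentConfig(input_shape=(total_m, 4096), block_size=128)
        for total_m in total_m_values
    ]
```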