
Commit d43090a

[benchmarks] Add inference-only roofline for float8 (#3167)
add inference only roofline
1 parent ce9f009 commit d43090a

4 files changed: +99 -46 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ aten/build/
 aten/src/ATen/Config.h
 aten/src/ATen/cuda/CUDAConfig.h
 benchmarks/.data
+benchmarks/data
 caffe2/cpp_test/
 dist/
 docs/build/

benchmarks/float8/float8_inference_roofline.py

Lines changed: 33 additions & 21 deletions
@@ -29,6 +29,7 @@
 import torch
 import torch.nn as nn
 import tqdm
+from tabulate import tabulate
 from torch.profiler import ProfilerActivity, profile
 from utils import (
     get_gpu_kernel_gemm_time_s,
@@ -77,8 +78,11 @@ def get_gemm_times(
     K: int,
     N: int,
     fast_accum: bool,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
 ):
+    assert recipe_name in {"rowwise"}, (
+        "Only support real benchmarks for 'rowwise' recipe for now"
+    )
     device = torch.device("cuda")
 
     # bf16 time
@@ -100,7 +104,7 @@ def get_gemm_times(
         .contiguous()
         .t()
     )
-    if float8_recipe_name in ("rowwise"):
+    if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
     else:
@@ -118,41 +122,47 @@ def do_matmul(A, B):
 
 def run(
     outfile: str,
+    recipe_name: str,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
-    float8_recipe_name: Optional[str] = None,
 ):
     """
     Args:
+    * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     """
-
-    assert float8_recipe_name is not None, "unsupported"
-
-    print(f"GPU: {torch.cuda.get_device_name(0)}")
-    print(f"torch version: {torch.__version__}")
-    print(f"torchao version: {torchao.__version__}")
-    print(f"do_benchmarks: {do_benchmarks}")
-    print(f"shape_gen_name: {shape_gen_name}")
-    print(f"float8_recipe_name: {float8_recipe_name}")
+    config_table = [
+        ["GPU", torch.cuda.get_device_name(0)],
+        ["torch version", torch.__version__],
+        ["torchao version", torchao.__version__],
+        ["recipe_name", recipe_name],
+        ["do_benchmarks", do_benchmarks],
+        ["shape_gen_name", shape_gen_name],
+    ]
+    print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
         M,
         K,
         N,
-        float8_recipe_name,
-    )
-    bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
-        M, K, N, torch.bfloat16, None, None
-    )
-    fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, float8_recipe_name, None
+        recipe_name,
     )
+    bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+
+    if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
+        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.float4_e2m1fn_x2, recipe_name
+        )
+    else:
+        gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
+        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.float8_e4m3fn, gemm_recipe_name
+        )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
     print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
@@ -219,7 +229,7 @@ def run(
                 K_val,
                 N_val,
                 True,
-                float8_recipe_name,
+                recipe_name,
             )
             b_bf16_gemm_time_s = bf16_g1
             b_fp8_gemm_time_s = f8_g1
@@ -261,6 +271,8 @@ def run(
             m_fp8_dyn = torch.compile(m_fp8_dyn)
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
+        r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+
         results.append(
             [
                 M_val,
@@ -273,7 +285,7 @@ def run(
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
                 r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
-                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                r_speedup,
                 # benchmarks - gemm
                 b_bf16_gemm_time_s,
                 b_fp8_gemm_time_s,
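
The speedup column this script tabulates is r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s). Below is a minimal standalone sketch of that computation for the rowwise recipe, assuming H100-class placeholder specs and a simplified max(compute-bound, memory-bound) gemm model; the constants and the roofline_speedup helper are illustrative and not part of torchao's API.

# Hedged sketch (placeholder specs, not torchao's roofline_utils)
PEAK_BF16_TOPS = 989e12   # assumed bf16 tensor-core peak, FLOP/s
PEAK_FP8_TOPS = 1979e12   # assumed fp8 tensor-core peak, FLOP/s
PEAK_MEM_BW = 3.35e12     # assumed HBM bandwidth, bytes/s
PCT_TOPS, PCT_BW = 0.7, 0.8  # assumed achievable fractions of peak

def roofline_speedup(M: int, K: int, N: int) -> float:
    gemm_flops = 2 * M * K * N

    def gemm_time_s(peak_tops: float, in_bytes_per_el: float) -> float:
        compute_s = gemm_flops / peak_tops / PCT_TOPS
        # read both operands, write the output in bf16
        mem_s = ((M * K + K * N) * in_bytes_per_el + M * N * 2) / PEAK_MEM_BW / PCT_BW
        return max(compute_s, mem_s)  # simplified roofline

    bf16_gemm_s = gemm_time_s(PEAK_BF16_TOPS, 2)
    fp8_gemm_s = gemm_time_s(PEAK_FP8_TOPS, 1)
    # rowwise quantization overhead: read activation in bf16, write fp8 + fp32 row scales
    ovhd_bytes = 2 * M * K + 1 * M * K + 4 * M
    ovhd_s = ovhd_bytes / PEAK_MEM_BW / PCT_BW
    return bf16_gemm_s / (fp8_gemm_s + ovhd_s)

print(f"{roofline_speedup(4096, 4096, 4096):.2f}x")  # ~1.7x with these placeholder specs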

benchmarks/float8/float8_roofline.py

Lines changed: 1 addition & 0 deletions
@@ -214,6 +214,7 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `mx_recipe_name (optional)`: MX format recipe
     * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """
 
torchao/testing/training/roofline_utils.py

Lines changed: 64 additions & 25 deletions
@@ -12,6 +12,7 @@
 BYTES_PER_EL_FLOAT4 = 0.5
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT32 = 4
 
 gpu_name_to_specs = {
     "NVIDIA H100": {
@@ -228,7 +229,7 @@ def get_individual_gemm_time_sympy(
     K: sympy.Symbol,
     N: sympy.Symbol,
     dtype,
-    mx_recipe_name,
+    mx_recipe_name: Optional[str],
     gpu_name: Optional[str] = None,
 ) -> sympy.Symbol:
     # compute bound
@@ -241,27 +242,24 @@
     elif dtype is torch.float4_e2m1fn_x2:
         peak_tops = specs["fp4_peak_tops"]
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
 
     # memory bound
     num_reads = M * K + K * N
     num_writes = M * N
 
     if mx_recipe_name is not None:
-        assert mx_recipe_name in (
-            "mxfp8_emulated",
-            "mxfp8_cublas",
-            "mxfp8_cublas_rceil",
-            "mxfp4_cutlass",
-        ), "unsupported"
+        assert mx_recipe_name.startswith(("mxfp8", "mxfp4", "nvfp4")), (
+            f"Unsupported recipe {mx_recipe_name}"
+        )
         assert dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.float4_e2m1fn_x2,
         ), "unsupported"
         # adjust reads for MX scaling
-        block_size = 32
+        block_size = 32 if mx_recipe_name.startswith("mx") else 16
         num_scale_reads = num_reads // block_size
         # note: e8m0 bytes per element is the same as for e4m3|e5m2
         num_reads = num_reads + num_scale_reads
@@ -274,7 +272,7 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     mem_gemm_time_s = (
         bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )
@@ -375,28 +373,68 @@ def get_inference_tensor_memory_traffic_ovhd_s(
     dim0,
     dim1,
     tensor_role: str,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
     fuse_with_prev=False,
 ) -> List[Union[sympy.Symbol, float]]:
     """
     Inference version of `get_tensor_memory_traffic_ovhd_s`.
     The only thing happening here is we quantize the activation.
     """
-    assert float8_recipe_name == "rowwise", "unsupported"
     assert fuse_with_prev is False, "unsupported"
+    assert tensor_role == "input", "inference only quantizes input activations"
 
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
     res_bytes = None
 
-    assert tensor_role == "input"
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> x_fp8
-    kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-    res_bytes = [
-        kernel_1_rw,
-    ]
+    allowed_recipes = {"tensorwise", "rowwise", "mxfp8*", "nvfp4*", "mxfp4*"}
+
+    match recipe_name:
+        case "tensorwise":
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+            # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
+            # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel
+            # kernel 3: read in bf16, write in float8
+            kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            res_bytes = [kernel_1_rw, kernel_3_rw]
+
+        case "rowwise":
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> x_fp8 (with per-row scaling)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            # add in the bytes for scale writes
+            kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+            res_bytes = [kernel_1_rw]
+
+        case name if name and name.startswith("mxfp8"):
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            # add in the bytes for scale writes in E8M0 format
+            kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // 32)
+            res_bytes = [kernel_1_rw]
+
+        case name if name and (name.startswith("mxfp4") or name.startswith("nvfp4")):
+            # For NVFP4, assume minimal overhead since it's primarily a compute format
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
+            if name.startswith("nvfp4"):
+                kernel_1_rw += BYTES_PER_EL_FLOAT32  # single scale factor
+            # add in the bytes for scale writes in E4M3 | E8M0
+            block_size = 32 if name.startswith("mxfp4") else 16
+            kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // block_size)
+            res_bytes = [kernel_1_rw]
+
+        case _:
+            raise ValueError(
+                f"Unknown recipe name: {recipe_name}. "
+                f"Allowed recipes: {allowed_recipes}"
+            )
 
     # convert from bytes to seconds
     res_s = [
@@ -414,7 +452,7 @@ def get_inference_float8_mem_sympy(
     M,
     K,
     N,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
     gpu_name: Optional[str] = None,
 ):
     specs = get_specs(gpu_name)
@@ -425,7 +463,7 @@
         M,
         K,
         tensor_role="input",
-        float8_recipe_name=float8_recipe_name,
+        recipe_name=recipe_name,
         fuse_with_prev=False,
     )
     res = sum([*fwd_fp8_input_mem])
@@ -437,11 +475,12 @@ def get_inference_gemm_time_sympy(
     K: sympy.Symbol,
     N: sympy.Symbol,
     dtype,
-    float8_recipe_name: Optional[str],
-    gpu_name: Optional[str],
+    recipe_name: Optional[str],
+    gpu_name: Optional[str] = None,
 ):
-    assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
     # note: this function is currently not super accurate for small shapes:
     # when M,K,N <= 1k,1k,1k it undercounts by around 2x
-    gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
+    gemm_output_time_s = get_individual_gemm_time_sympy(
+        M, K, N, dtype, recipe_name, gpu_name
+    )
     return gemm_output_time_s
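
As a sanity check on the new per-recipe traffic formulas in get_inference_tensor_memory_traffic_ovhd_s, here is a small standalone sketch that evaluates the same byte counts for a concrete (dim0, dim1) = (4096, 4096) bf16 activation. The quant_ovhd_bytes helper is a re-derivation for illustration under the same assumptions as the diff (bf16 input, one-byte e8m0/e4m3 scales); it is not torchao's API.

# Hedged sketch: total quantization read/write bytes per recipe, mirroring the
# formulas in the diff above (illustrative helper, not part of torchao).
BYTES_BF16, BYTES_FP8, BYTES_FP4, BYTES_FP32 = 2, 1, 0.5, 4

def quant_ovhd_bytes(dim0: int, dim1: int, recipe: str) -> float:
    numel = dim0 * dim1
    if recipe == "tensorwise":
        # kernel 1 reads bf16 for max-abs; kernel 3 reads bf16 and writes fp8
        return BYTES_BF16 * numel + (BYTES_BF16 + BYTES_FP8) * numel
    if recipe == "rowwise":
        # one kernel: read bf16, write fp8, plus one fp32 scale per row
        return (BYTES_BF16 + BYTES_FP8) * numel + BYTES_FP32 * dim0
    if recipe.startswith("mxfp8"):
        # read bf16, write fp8, plus one e8m0 scale byte per block of 32
        return (BYTES_BF16 + BYTES_FP8) * numel + BYTES_FP8 * dim0 * (dim1 // 32)
    if recipe.startswith(("mxfp4", "nvfp4")):
        block = 32 if recipe.startswith("mxfp4") else 16
        tensor_scale = BYTES_FP32 if recipe.startswith("nvfp4") else 0
        return (
            (BYTES_BF16 + BYTES_FP4) * numel
            + BYTES_FP8 * dim0 * (dim1 // block)
            + tensor_scale
        )
    raise ValueError(f"unknown recipe: {recipe}")

for r in ("tensorwise", "rowwise", "mxfp8", "nvfp4"):
    print(f"{r:>10}: {quant_ovhd_bytes(4096, 4096, r) / 2**20:.1f} MiB of traffic")

With these formulas, rowwise and mxfp8 move roughly 3 bytes per element, tensorwise pays an extra bf16 read for its max-abs pass (about 5 bytes per element), and the fp4 recipes drop to about 2.5 bytes per element plus scale traffic.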

0 commit comments
