Commit 6452b4a

mxfp8 inference roofline: add fusion to observed (#3223)
Summary: adds an option to benchmark with relu -> linear, to capture the impact of fusing the activation into the quant kernel.

Test Plan:

```bash
(pt_nightly_312_2) [vasiliy@devgpu023.atn1 ~/local/ao (20251021_inference_fusion_modeling)]$ python benchmarks/float8/float8_inference_roofline.py ~/local/tmp/test.csv --recipe_name mxfp8_cublas --shape_gen_name pow2_extended --enable_fusion_modeling True
```
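For orientation, here is a minimal sketch of the two model variants the new flag switches between; the shape, the input tensor, and the `torch.compile` call are illustrative assumptions, not code from this commit:

```python
# Minimal sketch (shape and compile usage are hypothetical) of what
# enable_fusion_modeling changes in the benchmarked model.
import torch
import torch.nn as nn

K_val, N_val = 4096, 4096  # hypothetical shape, for illustration only

# enable_fusion_modeling=False: benchmark the gemm alone
m_gemm_only = nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()

# enable_fusion_modeling=True: benchmark activation -> gemm, so that a
# compiled quantized run can fuse the relu into the quant kernel and the
# measured time reflects that fusion
m_act_gemm = (
    nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
)

x = torch.randn(4096, K_val, device="cuda", dtype=torch.bfloat16)
with torch.inference_mode():
    y = torch.compile(m_act_gemm)(x)
```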
1 parent 13595c5 commit 6452b4a

File tree: 1 file changed, +9, -3 lines


benchmarks/float8/float8_inference_roofline.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -168,6 +168,7 @@ def run(
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
@@ -176,6 +177,7 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
+    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
     """
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
@@ -184,6 +186,7 @@ def run(
         ["recipe_name", recipe_name],
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
+        ["enable_fusion_modeling", enable_fusion_modeling],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

@@ -194,6 +197,7 @@ def run(
         K,
         N,
         recipe_name,
+        # TODO(future): also enable fusion modeling here
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)

@@ -287,9 +291,11 @@ def run(
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = (
-                nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
-            )
+            if not enable_fusion_modeling:
+                m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+            else:
+                m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
+            m_orig = m_orig.cuda().bfloat16()
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()
```
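For completeness, the CLI call in the test plan corresponds roughly to the direct invocation sketched below; the assumption that the CSV path maps to `run()`'s first positional parameter is inferred from the CLI call, not confirmed by this diff:

```python
# Hypothetical direct call mirroring the test plan's CLI invocation;
# assumes run()'s first positional argument is the output CSV path.
from benchmarks.float8.float8_inference_roofline import run

run(
    "~/local/tmp/test.csv",
    recipe_name="mxfp8_cublas",
    shape_gen_name="pow2_extended",
    enable_fusion_modeling=True,  # the flag added in this commit
)
```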
