Skip to content

Commit beee153

Browse files
authored
enable custom MKN in inference roofline script (#3224)
Summary: Small refactor to enable the user to pass a custom MKN to the inference roofline script. Test Plan: ``` python benchmarks/float8/float8_inference_roofline.py ~/local/tmp/test.csv --recipe_name mxfp8_cublas --shape_gen_name custom --M 3072 --K 3072 --N 3072 ``` Reviewers: Subscribers: Tasks: Tags:
1 parent 6452b4a commit beee153

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

benchmarks/float8/float8_inference_roofline.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
NVFP4InferenceConfig,
4747
NVFP4MMConfig,
4848
)
49+
from torchao.prototype.mx_formats.utils import to_blocked
4950
from torchao.quantization.quant_api import (
5051
Float8DynamicActivationFloat8WeightConfig,
5152
PerRow,
@@ -134,12 +135,18 @@ def get_gemm_times(
134135
elif recipe_name == "mxfp8_cublas":
135136
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
136137
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
138+
scale_a = to_blocked(scale_a)
139+
scale_b = to_blocked(scale_b)
137140
elif recipe_name == "mxfp4_cutlass":
138141
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
139142
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
143+
scale_a = to_blocked(scale_a)
144+
scale_b = to_blocked(scale_b)
140145
elif recipe_name == "nvfp4":
141146
scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
142147
scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
148+
scale_a = to_blocked(scale_a)
149+
scale_b = to_blocked(scale_b)
143150

144151
else:
145152
assert False, "unsupported"
@@ -166,6 +173,9 @@ def run(
166173
recipe_name: str,
167174
do_benchmarks: bool = True,
168175
shape_gen_name: str = "pow2",
176+
M: Optional[int] = None,
177+
K: Optional[int] = None,
178+
N: Optional[int] = None,
169179
n_limit: Optional[int] = None,
170180
save_profile_traces: bool = False,
171181
enable_fusion_modeling: bool = False,
@@ -174,7 +184,8 @@ def run(
174184
Args:
175185
* `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
176186
* `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
177-
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
187+
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
188+
* `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
178189
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
179190
# `save_profile_traces (optional)`: if True, saves profiling traces
180191
# `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
@@ -187,9 +198,13 @@ def run(
187198
["do_benchmarks", do_benchmarks],
188199
["shape_gen_name", shape_gen_name],
189200
["enable_fusion_modeling", enable_fusion_modeling],
201+
["MKN", f"{M} {K} {N}"],
190202
]
191203
print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
192204

205+
# reassign user specified MKN, so we can use them for sympy
206+
user_M, user_K, user_N = M, K, N
207+
193208
M, K, N = sympy.symbols("M K N")
194209

195210
fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
@@ -245,7 +260,7 @@ def run(
245260
]
246261
results = []
247262

248-
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None)
263+
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
249264

250265
for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)):
251266
if n_limit is not None and idx >= n_limit:

0 commit comments

Comments (0)