
Commit 42fc6bd

jerryzh168 authored and jerryzh committed

Add support for e2e benchmark for conv2d/conv3d (#3329)

Summary:

att, we added this to float8_inference_roofline to reuse code but we haven't enabled the roofline feature. For now we just need the e2e speedup time for single conv2d/conv3d against bf16 to understand the speedup expectation. Also added B200 hardware spec.

Test Plan:

python $SCRIPT_PATH $OUTPUT_FILE \
    --recipe_name $RECIPE_NAME \
    --shape_gen_name $SHAPE_GEN_NAME \
    --M $M --K $K --N $N \
    --D $D --H $H --W $W \
    --kernel_size $kernel_size \
    --op_name conv3d

This doesn't run yet because OSS fbgemm can't be installed on the B200 machine.

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: jerryzh <jerryzh@devgpu009.kcm2.facebook.com>
1 parent 726607d commit 42fc6bd
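
For reference, a concrete version of the test plan invocation might look like the following; the output file, shape, and recipe values are illustrative only, and `custom` is assumed here to be the shape_gen_name that consumes the user-supplied M/K/N:

    python benchmarks/float8/float8_inference_roofline.py conv3d_results.csv \
        --recipe_name tensorwise \
        --shape_gen_name custom \
        --M 32 --K 64 --N 128 \
        --D 16 --H 56 --W 56 \
        --kernel_size 3 \
        --op_name conv3d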

File tree

2 files changed: +177 -65 lines changed


benchmarks/float8/float8_inference_roofline.py

Lines changed: 161 additions & 65 deletions
@@ -50,6 +50,7 @@
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
+    PerTensor,
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
@@ -179,6 +180,11 @@ def run(
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
     enable_fusion_modeling: bool = False,
+    op_name: str = "linear",
+    D: Optional[int] = None,
+    H: Optional[int] = None,
+    W: Optional[int] = None,
+    kernel_size: Optional[int] = None,
 ):
     """
     Args:
@@ -189,7 +195,29 @@ def run(
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
     # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
+    # `op_name`: linear, conv2d or conv3d, decides which op to benchmark
+    # `D`, `H`, `W`: spatial dimensions for conv3d / conv2d
+    # `kernel_size`: kernel_size for conv3d / conv2d
     """
+    _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
+    assert op_name in _SUPPORTED_OPS, (
+        f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}"
+    )
+    if op_name == "conv2d":
+        assert H is not None and W is not None, (
+            "Expected H, W to be specified for conv2d"
+        )
+        assert kernel_size is not None, (
+            "Expected kernel_size to be specified for conv2d"
+        )
+    elif op_name == "conv3d":
+        assert D is not None and H is not None and W is not None, (
+            "Expected D, H, W to be specified for conv3d"
+        )
+        assert kernel_size is not None, (
+            "Expected kernel_size to be specified for conv3d"
+        )
+
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
         ["torch version", torch.__version__],
@@ -198,7 +226,10 @@ def run(
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
         ["enable_fusion_modeling", enable_fusion_modeling],
+        ["op_name", op_name],
         ["MKN", f"{M} {K} {N}"],
+        ["DHW", f"{D} {H} {W}"],
+        ["kernel_size", kernel_size],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

@@ -207,33 +238,45 @@ def run(

     M, K, N = sympy.symbols("M K N")

-    fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
-        M,
-        K,
-        N,
-        recipe_name,
-        # TODO(future): also enable fusion modeling here
-    )
-    bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
-
-    if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
-        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-            M, K, N, torch.float4_e2m1fn_x2, recipe_name
+    if op_name == "linear":
+        fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
+            M,
+            K,
+            N,
+            recipe_name,
+            # TODO(future): also enable fusion modeling here
         )
-    else:
-        gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
-        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-            M, K, N, torch.float8_e4m3fn, gemm_recipe_name
+        bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.bfloat16, None
         )
-    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
-    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
-    print()

+        if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
+            fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+                M, K, N, torch.float4_e2m1fn_x2, recipe_name
+            )
+        else:
+            gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
+            fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+                M, K, N, torch.float8_e4m3fn, gemm_recipe_name
+            )
+        print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+        print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+        print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
+        print()
+    else:
+        # TODO: enable roofline analysis for conv
+        pass
+
+    # Note: roofline for conv2d/conv3d is not added yet, so most of the
+    # roofline columns for conv2d/conv3d are left out for now
     headers = [
         "fwd_M",
         "fwd_K",
         "fwd_N",
+        "D",
+        "H",
+        "W",
+        "kernel_size",
         # roofline - gemm time (fwd + bwd, 3 gemms)
         "r_bf16_gemm_s",
         "r_fp8_gemm_s",
@@ -258,6 +301,7 @@ def run(
         "rb_bf16_gemm_ratio",
         "rb_fp8_gemm_ratio",
     ]
+
    results = []

    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
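
The linear path builds its roofline estimates as symbolic sympy expressions once, then substitutes each benchmarked shape into them inside the loop below. A minimal self-contained sketch of that pattern, with a toy compute-bound expression standing in for what get_inference_gemm_time_sympy actually returns:

    import sympy

    M, K, N = sympy.symbols("M K N")

    # toy stand-in for a roofline gemm time: 2*M*K*N flops over an
    # assumed 1e15 flops/sec of achievable throughput
    bf16_gemm_time_sympy = 2 * M * K * N / 1e15

    # per-shape evaluation, as done in the benchmark loop; note the cast
    # from sympy.core.numbers.Float to float so pandas formatting works
    r_bf16_gemm_time_s = float(
        bf16_gemm_time_sympy.subs(M, 1024).subs(K, 4096).subs(N, 4096)
    )
    print(r_bf16_gemm_time_s)  # ~3.44e-05 seconds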
@@ -266,54 +310,93 @@ def run(
         if n_limit is not None and idx >= n_limit:
             break

-        # use roofline model to estimate gemm time
-        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
-        r_bf16_gemm_time_s = float(
-            bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        r_fp8_gemm_time_s = float(
-            fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-
-        # if enabled, also measured observed gemm time
-        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
-        rb_bf16_gemm_ratio = -1
-        rb_fp8_gemm_ratio = -1
+        if op_name == "linear":
+            # use roofline model to estimate gemm time
+            # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+            r_bf16_gemm_time_s = float(
+                bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            )
+            r_fp8_gemm_time_s = float(
+                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            )

-        if do_benchmarks:
-            # TODO(future): make the bf16 gemm times exactly match the e2e
-            # benchmarks, there is a slight deviation, probably related to gemm
-            # operand memory formats/transpositions below not exactly matching
-            # what PyTorch core is doing for `torch.mm`
-            # input @ weight_t = output
-            bf16_g1, f8_g1 = get_gemm_times(
-                M_val,
-                K_val,
-                N_val,
-                True,
-                recipe_name,
+            # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+            r_fp8_ovhd_time_s = float(
+                fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             )
-            b_bf16_gemm_time_s = bf16_g1
-            b_fp8_gemm_time_s = f8_g1
-            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
-            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
-
-        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
-        r_fp8_ovhd_time_s = float(
-            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
+            r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+            r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+
+            # if enabled, also measured observed gemm time
+            b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+            rb_bf16_gemm_ratio = -1
+            rb_fp8_gemm_ratio = -1
+
+            if do_benchmarks:
+                # TODO(future): make the bf16 gemm times exactly match the e2e
+                # benchmarks, there is a slight deviation, probably related to gemm
+                # operand memory formats/transpositions below not exactly matching
+                # what PyTorch core is doing for `torch.mm`
+                # input @ weight_t = output
+                bf16_g1, f8_g1 = get_gemm_times(
+                    M_val,
+                    K_val,
+                    N_val,
+                    True,
+                    recipe_name,
+                )
+                b_bf16_gemm_time_s = bf16_g1
+                b_fp8_gemm_time_s = f8_g1
+                rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+                rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
+
+        else:
+            # roofline analysis for conv2d/conv3d is not added yet
+            r_bf16_gemm_time_s = None
+            r_fp8_gemm_time_s = None
+
+            r_fp8_ovhd_time_s = None
+            r_fp8_gemm_and_ovhd_s = None
+            r_speedup = None
+
+            # real gemm benchmark time, also not added yet
+            # if enabled, also measured observed gemm time
+            b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+            # gemm roofline ratio achieved in real benchmark
+            rb_bf16_gemm_ratio = -1
+            rb_fp8_gemm_ratio = -1

         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            if not enable_fusion_modeling:
-                m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+            if op_name == "conv2d":
+                m_orig = nn.Sequential(
+                    nn.Conv2d(K_val, N_val, kernel_size, bias=False)
+                ).to(memory_format=torch.channels_last)
+            elif op_name == "conv3d":
+                m_orig = nn.Sequential(
+                    nn.Conv3d(K_val, N_val, kernel_size, bias=False)
+                ).to(memory_format=torch.channels_last_3d)
             else:
-                m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
+                if not enable_fusion_modeling:
+                    m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+                else:
+                    m_orig = nn.Sequential(
+                        nn.ReLU(), nn.Linear(K_val, N_val, bias=False)
+                    )
             m_orig = m_orig.cuda().bfloat16()
-            x = torch.randn(
-                M_val, K_val, dtype=torch.bfloat16, device="cuda"
-            ).requires_grad_()
+            if op_name == "conv2d":
+                x = torch.randn(
+                    M_val, K_val, H, W, dtype=torch.bfloat16, device="cuda"
+                ).to(memory_format=torch.channels_last)
+            elif op_name == "conv3d":
+                x = torch.randn(
+                    M_val, K_val, D, H, W, dtype=torch.bfloat16, device="cuda"
+                ).to(memory_format=torch.channels_last_3d)
+            else:
+                x = torch.randn(
+                    M_val, K_val, dtype=torch.bfloat16, device="cuda"
+                ).requires_grad_()

             # get the bf16 gpu kernel time
             torch._dynamo.reset()
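
As a standalone sanity check of the conv2d construction above, a sketch with hypothetical shape values; M_val, K_val, N_val map to batch, in_channels, out_channels, and with nn.Conv2d's default stride=1 and padding=0 each output spatial dim shrinks by kernel_size - 1:

    import torch
    import torch.nn as nn

    M_val, K_val, N_val = 8, 64, 128  # batch, in_channels, out_channels
    H, W, kernel_size = 56, 56, 3

    m_orig = (
        nn.Sequential(nn.Conv2d(K_val, N_val, kernel_size, bias=False))
        .to(memory_format=torch.channels_last)
        .cuda()
        .bfloat16()
    )
    x = torch.randn(M_val, K_val, H, W, dtype=torch.bfloat16, device="cuda").to(
        memory_format=torch.channels_last
    )

    y = m_orig(x)
    # stride=1, padding=0: output spatial dims are H/W minus (kernel_size - 1)
    assert y.shape == (M_val, N_val, H - kernel_size + 1, W - kernel_size + 1)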
@@ -327,7 +410,11 @@ def run(
             # get the float8 dynamic scaling gpu kernel time
             torch._dynamo.reset()

-            if recipe_name == "rowwise":
+            if recipe_name == "tensorwise":
+                config = Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(),
+                )
+            elif recipe_name == "rowwise":
                 config = Float8DynamicActivationFloat8WeightConfig(
                     granularity=PerRow(),
                     # for now, use TORCH. In the future might be interesting
@@ -355,7 +442,14 @@ def run(
                 assert False, "unsupported"

             m_fp8_dyn = copy.deepcopy(m_orig)
-            quantize_(m_fp8_dyn, config)
+            if op_name == "linear":
+                quantize_(m_fp8_dyn, config)
+            elif op_name == "conv2d":
+                _is_conv2d = lambda m, fqn: isinstance(m, torch.nn.Conv2d)
+                quantize_(m_fp8_dyn, config, filter_fn=_is_conv2d)
+            else:
+                _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+                quantize_(m_fp8_dyn, config, filter_fn=_is_conv3d)

             m_fp8_dyn = torch.compile(m_fp8_dyn)

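Putting the new tensorwise config and the conv filter together, the quantization step reduces to the sketch below. It assumes m_orig from the previous sketch and a torchao build where Float8DynamicActivationFloat8WeightConfig supports conv modules; per the test plan, this path has not been exercised end to end yet:

    import copy

    import torch
    from torchao.quantization.quant_api import (
        Float8DynamicActivationFloat8WeightConfig,
        PerTensor,
        quantize_,
    )

    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())

    # quantize_ swaps weights in place; the filter_fn restricts it to Conv2d
    # modules, since by default quantize_ only targets nn.Linear
    m_fp8_dyn = copy.deepcopy(m_orig)
    _is_conv2d = lambda m, fqn: isinstance(m, torch.nn.Conv2d)
    quantize_(m_fp8_dyn, config, filter_fn=_is_conv2d)
    m_fp8_dyn = torch.compile(m_fp8_dyn)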
@@ -364,20 +458,22 @@ def run(
                 fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json"
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename)

-        r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
-
         results.append(
             [
                 M_val,
                 K_val,
                 N_val,
+                D,
+                H,
+                W,
+                kernel_size,
                 # roofline - gemm
                 r_bf16_gemm_time_s,
                 r_fp8_gemm_time_s,
                 # roofline - fp8 overhead
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
-                r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
+                r_fp8_gemm_and_ovhd_s,
                 r_speedup,
                 # benchmarks - gemm
                 b_bf16_gemm_time_s,
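
Since the roofline columns stay None for conv, the number this PR is after is the measured e2e ratio b_bf16_e2e_time_s / b_fp8_e2e_time_s. A sketch of measuring that outside the script, reusing m_orig, m_fp8_dyn, and x from the sketches above and torch.utils.benchmark in place of the script's get_gpu_kernel_time profiler helper:

    import torch.utils.benchmark as benchmark

    def bench_seconds(model, inp):
        # median wall-clock seconds per forward pass; Timer handles CUDA sync
        timer = benchmark.Timer(
            stmt="model(inp)", globals={"model": model, "inp": inp}
        )
        return timer.blocked_autorange().median

    b_bf16_e2e_time_s = bench_seconds(m_orig, x)
    b_fp8_e2e_time_s = bench_seconds(m_fp8_dyn, x)
    print(f"fp8 e2e speedup over bf16: {b_bf16_e2e_time_s / b_fp8_e2e_time_s:.2f}x")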

torchao/testing/training/roofline_utils.py

Lines changed: 16 additions & 0 deletions
@@ -43,6 +43,22 @@
         # TODO(future): measure once we have the hardware
         "pct_achievable_mem_bw": 0.92,
     },
+    "NVIDIA GB200": {
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 19,
+        # divide by 2 because no sparsity
+        "bf16_peak_tops": 2.25e15,
+        "fp8_peak_tops": 4.5e15,
+        "fp4_peak_tops": 9.0e15,
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 20
+        # 8.0 TB per second
+        "peak_mem_bw_bytes_sec": 8.0e12,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_gemm_tops": 0.78,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_mem_bw": 0.92,
+    },
     "AMD Instinct MI300X": {
         # https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf, page 1,
         "bf16_peak_tops": 1307e12,

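For context on how these entries feed a roofline estimate: an op's estimated time is the max of its compute time (flops over achievable tops) and its memory time (bytes moved over achievable bandwidth), with the pct_achievable factors derating the peaks. A minimal sketch using the GB200 numbers above; the helper is illustrative, not torchao's actual API:

    def roofline_gemm_time_s(M, K, N, spec, bytes_per_el=2):
        # compute-bound term: 2*M*K*N flops for a gemm in bf16
        compute_s = 2 * M * K * N / (
            spec["bf16_peak_tops"] * spec["pct_achievable_gemm_tops"]
        )
        # memory-bound term: read A (M*K) and B (K*N), write C (M*N)
        mem_bytes = bytes_per_el * (M * K + K * N + M * N)
        mem_s = mem_bytes / (
            spec["peak_mem_bw_bytes_sec"] * spec["pct_achievable_mem_bw"]
        )
        return max(compute_s, mem_s)

    gb200 = {
        "bf16_peak_tops": 2.25e15,
        "peak_mem_bw_bytes_sec": 8.0e12,
        "pct_achievable_gemm_tops": 0.78,
        "pct_achievable_mem_bw": 0.92,
    }
    # a large gemm lands compute-bound at roughly 1.25 ms
    print(roofline_gemm_time_s(16384, 8192, 8192, gb200))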