Skip to content

Commit 6c24a7a

Browse files
authored
update roofline benchmark with mxfp4 (#3152)
Update [ghstack-poisoned]
1 parent 8c6d754 commit 6c24a7a

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

benchmarks/float8/float8_roofline.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -245,8 +245,12 @@ def run(
245245
bf16_gemm_time_sympy = get_gemm_time_sympy(
246246
M, K, N, torch.bfloat16, None, None, None
247247
)
248+
lowp_input_dtype = torch.float8_e4m3fn
249+
if mx_recipe_name == "mxfp4_cutlass":
250+
lowp_input_dtype = torch.float4_e2m1fn_x2
251+
248252
fp8_gemm_time_sympy = get_gemm_time_sympy(
249-
M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name, None
253+
M, K, N, lowp_input_dtype, float8_recipe_name, mx_recipe_name, None
250254
)
251255
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
252256
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
@@ -304,6 +308,8 @@ def run(
304308
rb_fp8_gemm_ratio = -1
305309

306310
if do_benchmarks:
311+
assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
312+
307313
# TODO(future): make the bf16 gemm times exactly match the e2e
308314
# benchmarks, there is a slight deviation, probably related to gemm
309315
# operand memory formats/transpositions below not exactly matching

torchao/testing/training/roofline_utils.py

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import sympy
1010
import torch
1111

12+
BYTES_PER_EL_FLOAT4 = 0.5
1213
BYTES_PER_EL_FLOAT8 = 1
1314
BYTES_PER_EL_BF16 = 2
1415

@@ -190,16 +191,24 @@ def get_tensor_memory_traffic_ovhd_s(
190191
"mxfp8_emulated",
191192
"mxfp8_cublas",
192193
"mxfp8_cublas_rceil",
194+
"mxfp4_cutlass",
193195
), "unsupported"
194196
# For now, assume that we can't profitably fuse kernel 1 and kernel 2
195197
# x_bf16 = ...
196198
# kernel 1: x_bf16 -> x_mxfp8_dim0
197199
# kernel 2: x_bf16 -> x_mxfp8_dim1
198-
if fuse_with_prev:
199-
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
200+
if mx_recipe_name == "mxfp4_cutlass":
201+
if fuse_with_prev:
202+
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT4 * numel
203+
else:
204+
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
205+
kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
200206
else:
201-
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
202-
kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
207+
if fuse_with_prev:
208+
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
209+
else:
210+
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
211+
kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
203212
res_bytes = [kernel_1_rw, kernel_2_rw]
204213

205214
# convert from bytes to seconds
@@ -229,6 +238,8 @@ def get_individual_gemm_time_sympy(
229238
peak_tops = specs["bf16_peak_tops"]
230239
elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
231240
peak_tops = specs["fp8_peak_tops"]
241+
elif dtype is torch.float4_e2m1fn_x2:
242+
peak_tops = specs["fp4_peak_tops"]
232243
else:
233244
assert False, "unsupported"
234245
compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
@@ -242,8 +253,13 @@ def get_individual_gemm_time_sympy(
242253
"mxfp8_emulated",
243254
"mxfp8_cublas",
244255
"mxfp8_cublas_rceil",
256+
"mxfp4_cutlass",
257+
), "unsupported"
258+
assert dtype in (
259+
torch.float8_e4m3fn,
260+
torch.float8_e5m2,
261+
torch.float4_e2m1fn_x2,
245262
), "unsupported"
246-
assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported"
247263
# adjust reads for MX scaling
248264
block_size = 32
249265
num_scale_reads = num_reads // block_size
@@ -255,6 +271,8 @@ def get_individual_gemm_time_sympy(
255271
elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
256272
# read in float8, output in bfloat16
257273
bytes_rw = num_reads * BYTES_PER_EL_FLOAT8 + num_writes * BYTES_PER_EL_BF16
274+
elif dtype is torch.float4_e2m1fn_x2:
275+
bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
258276
else:
259277
assert False, "unsupported"
260278
mem_gemm_time_s = (

0 commit comments

Comments
 (0)