Skip to content

Commit 13595c5

Browse files
authored
fix mxfp8 matmul benchmark (#3221)
Summary: Adds padding to the scales to properly support shapes where M % 128 != 0. Test Plan: ``` python benchmarks/float8/bench_matmul.py --shape_gen_name custom --recipe mxfp8_cublas --M 17 --K 32 --N 16 ``` Reviewers: Subscribers: Tasks: Tags:
1 parent 82ae011 commit 13595c5

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

benchmarks/float8/bench_matmul.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from torchao.ops import mx_fp4_bf16
1919
from torchao.prototype.mx_formats.mx_tensor import to_mx
20+
from torchao.prototype.mx_formats.utils import to_blocked
2021
from torchao.testing.training.roofline_utils import get_specs
2122
from torchao.utils import is_MI300
2223

@@ -125,10 +126,16 @@ def run(
125126
elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
126127
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
127128
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
129+
# pad if needed
130+
scale_a = to_blocked(scale_a)
131+
scale_b = to_blocked(scale_b)
128132
elif recipe == "nvfp4":
129133
# Use the blockwise scales from nvfp4_quantize
130134
scale_a = A_scales.view(torch.float8_e4m3fn)
131135
scale_b = B_scales.view(torch.float8_e4m3fn)
136+
# pad if needed
137+
scale_a = to_blocked(scale_a)
138+
scale_b = to_blocked(scale_b)
132139
else:
133140
assert False, f"unknown recipe {recipe}"
134141

0 commit comments

Comments (0)