
Commit d1a7fbc

extend inference roofline with real benchmarks (#3194)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 7211104 commit d1a7fbc

File tree

1 file changed (+81, -25 lines)


benchmarks/float8/float8_inference_roofline.py

Lines changed: 81 additions & 25 deletions
@@ -38,6 +38,14 @@
 )

 import torchao
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -80,40 +88,67 @@ def get_gemm_times(
     fast_accum: bool,
     recipe_name: Optional[str],
 ):
-    assert recipe_name in {"rowwise"}, (
-        "Only support real benchmarks for 'rowwise' recipe for now"
-    )
     device = torch.device("cuda")

     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)

     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

-    e4m3_dtype = torch.float8_e4m3fn
-    if torch.version.hip and torch.cuda.is_available() and is_MI300():
-        e4m3_dtype = torch.float8_e4m3fnuz
-    d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
-    A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
-    B = (
-        torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
-        .view(d2)
-        .t()
-        .contiguous()
-        .t()
-    )
+    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
+        d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
+        A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
+            d1
+        )
+        B = (
+            torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8)
+            .t()
+            .contiguous()
+            .t()
+            .view(d2)
+        )
+    else:
+        e4m3_dtype = torch.float8_e4m3fn
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            e4m3_dtype = torch.float8_e4m3fnuz
+        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
+        A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
+        B = (
+            torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
+            .view(d2)
+            .t()
+            .contiguous()
+            .t()
+        )
+
     if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "mxfp4_cutlass":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "nvfp4":
+        scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+
     else:
         assert False, "unsupported"

     def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        if recipe_name == "mxfp4_cutlass":
+            return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b)
+        if recipe_name == "nvfp4":
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+            )
+        else:
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )

     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)


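A note on the scale layouts above: the MX recipes carry one float8_e8m0fnu scale per 32-element block along K, NVFP4 carries one float8_e4m3fn scale per 16-element block, and fp4 (e2m1) elements are packed two per byte. A minimal sketch of that shape arithmetic, with illustrative dimensions that are not part of this commit:

# Illustrative only: shape arithmetic behind the scale tensors constructed above.
M, K, N = 4096, 8192, 4096

# mxfp8_cublas / mxfp4_cutlass: one float8_e8m0fnu scale per 32-wide block of K.
mx_scale_a_shape = (M, K // 32)
mx_scale_b_shape = (N, K // 32)

# nvfp4: one float8_e4m3fn scale per 16-wide block of K.
nvfp4_scale_a_shape = (M, K // 16)
nvfp4_scale_b_shape = (N, K // 16)

# fp4 (float4_e2m1fn_x2) packs two elements per byte, so the uint8 backing
# storage for an (M, K) operand is (M, K // 2) before the .view() call above.
fp4_packed_a_shape = (M, K // 2)

print(mx_scale_a_shape, nvfp4_scale_a_shape, fp4_packed_a_shape)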
@@ -259,12 +294,33 @@ def run(
     # get the float8 dynamic scaling gpu kernel time
     torch._dynamo.reset()

-    config = Float8DynamicActivationFloat8WeightConfig(
-        granularity=PerRow(),
-        # for now, use TORCH. In the future might be interesting
-        # to benchmark AUTO and FBGEMM.
-        kernel_preference=KernelPreference.TORCH,
-    )
+    if recipe_name == "rowwise":
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+            # for now, use TORCH. In the future might be interesting
+            # to benchmark AUTO and FBGEMM.
+            kernel_preference=KernelPreference.TORCH,
+        )
+    elif recipe_name == "mxfp8_cublas":
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+        )
+    elif recipe_name == "mxfp4_cutlass":
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float4_e2m1fn_x2,
+            weight_dtype=torch.float4_e2m1fn_x2,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+        )
+    elif recipe_name == "nvfp4":
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.DYNAMIC,
+            use_dynamic_per_tensor_scale=False,
+        )
+    else:
+        assert False, "unsupported"
+
     m_fp8_dyn = copy.deepcopy(m_orig)
     quantize_(m_fp8_dyn, config)

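For context on how the configs above are consumed: the benchmark quantizes a copy of the model with whichever config matches recipe_name and then times it. A rough usage sketch, assuming a CUDA build of PyTorch and a torchao build with MX support; the toy model and sizes are illustrative, and quantize_ is assumed to be importable from torchao.quantization:

import copy

import torch
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_

# Toy bf16 model standing in for m_orig in the benchmark script (illustrative).
m_orig = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
m_orig = m_orig.cuda().to(torch.bfloat16)

# Same config the script builds for recipe_name == "mxfp8_cublas".
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

# Mirrors m_fp8_dyn = copy.deepcopy(m_orig); quantize_(m_fp8_dyn, config).
m_mx = copy.deepcopy(m_orig)
quantize_(m_mx, config)

with torch.no_grad():
    y = m_mx(torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16))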