 import sympy
 import torch
 
+BYTES_PER_EL_FLOAT4 = 0.5
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
 
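The new constant encodes FP4 packing: two e2m1 values share a single byte, so the per-element cost is 0.5 bytes. A quick sanity check of the per-dtype footprint for an arbitrary 4096 x 4096 tensor (the shape is chosen only for illustration):

```python
# Illustrative only: byte footprint per dtype, using the constants from this diff.
numel = 4096 * 4096
print(numel * 0.5)  # fp4:  8_388_608 bytes (two elements per byte)
print(numel * 1)    # fp8: 16_777_216 bytes
print(numel * 2)    # bf16: 33_554_432 bytes
```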
@@ -190,16 +191,24 @@ def get_tensor_memory_traffic_ovhd_s(
190191 "mxfp8_emulated" ,
191192 "mxfp8_cublas" ,
192193 "mxfp8_cublas_rceil" ,
194+ "mxfp4_cutlass" ,
193195 ), "unsupported"
194196 # For now, assume that we can't profitably fuse kernel 1 and kernel 2
195197 # x_bf16 = ...
196198 # kernel 1: x_bf16 -> x_mxfp8_dim0
197199 # kernel 2: x_bf16 -> x_mxfp8_dim1
198- if fuse_with_prev :
199- kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
200+ if mx_recipe_name == "mxfp4_cutlass" :
201+ if fuse_with_prev :
202+ kernel_1_rw = 0 + BYTES_PER_EL_FLOAT4 * numel
203+ else :
204+ kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
205+ kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
200206 else :
201- kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
202- kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
207+ if fuse_with_prev :
208+ kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
209+ else :
210+ kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
211+ kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
203212 res_bytes = [kernel_1_rw , kernel_2_rw ]
204213
205214 # convert from bytes to seconds
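Both branches use the same two-kernel cast model: kernel 1 reads the bf16 input (for free when fused with the producing op) and writes the dim0 low-precision copy; kernel 2 always reads bf16 and writes the dim1 copy. Only the low-precision element size differs between mxfp8 and mxfp4. A self-contained sketch of that accounting (the helper name and sample `numel` below are ours, not part of the patch):

```python
BYTES_PER_EL_FLOAT4 = 0.5
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2

def mx_cast_traffic_bytes(numel, lp_bytes_per_el, fuse_with_prev=False):
    # kernel 1: bf16 -> low-precision copy along dim0
    # (if fused with the previous op, the bf16 read is free)
    read_1 = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
    kernel_1_rw = read_1 + lp_bytes_per_el * numel
    # kernel 2: bf16 -> low-precision copy along dim1 (never fused here)
    kernel_2_rw = BYTES_PER_EL_BF16 * numel + lp_bytes_per_el * numel
    return [kernel_1_rw, kernel_2_rw]

numel = 4096 * 4096
print(mx_cast_traffic_bytes(numel, BYTES_PER_EL_FLOAT8))  # mxfp8 path
print(mx_cast_traffic_bytes(numel, BYTES_PER_EL_FLOAT4))  # mxfp4 path
```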
@@ -229,6 +238,8 @@ def get_individual_gemm_time_sympy(
         peak_tops = specs["bf16_peak_tops"]
     elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         peak_tops = specs["fp8_peak_tops"]
+    elif dtype is torch.float4_e2m1fn_x2:
+        peak_tops = specs["fp4_peak_tops"]
     else:
         assert False, "unsupported"
     compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
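The compute-bound term divides GEMM FLOPs by the peak throughput of the chosen dtype, so the new branch assumes the `specs` dict carries an `fp4_peak_tops` entry. A runnable sketch with made-up peak numbers (placeholders only, not real hardware specs):

```python
import torch

# Placeholder specs: keys mirror the ones used in the diff, the numbers are
# invented purely to make the example runnable.
specs = {
    "bf16_peak_tops": 1.0e15,
    "fp8_peak_tops": 2.0e15,
    "fp4_peak_tops": 4.0e15,
    "pct_achievable_gemm_tops": 0.7,
}

def peak_tops_for(dtype):
    if dtype is torch.bfloat16:
        return specs["bf16_peak_tops"]
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return specs["fp8_peak_tops"]
    elif dtype is torch.float4_e2m1fn_x2:
        return specs["fp4_peak_tops"]
    raise AssertionError("unsupported")

M, K, N = 8192, 8192, 8192
gemm_ops = 2 * M * K * N
t = gemm_ops / peak_tops_for(torch.float4_e2m1fn_x2) / specs["pct_achievable_gemm_tops"]
print(t)  # ideal compute-bound time in seconds under the placeholder specs
```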
@@ -242,8 +253,13 @@ def get_individual_gemm_time_sympy(
242253 "mxfp8_emulated" ,
243254 "mxfp8_cublas" ,
244255 "mxfp8_cublas_rceil" ,
256+ "mxfp4_cutlass" ,
257+ ), "unsupported"
258+ assert dtype in (
259+ torch .float8_e4m3fn ,
260+ torch .float8_e5m2 ,
261+ torch .float4_e2m1fn_x2 ,
245262 ), "unsupported"
246- assert dtype in (torch .float8_e4m3fn , torch .float8_e5m2 ), "unsupported"
247263 # adjust reads for MX scaling
248264 block_size = 32
249265 num_scale_reads = num_reads // block_size
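The scale adjustment adds one e8m0 scale read per 32-element block, so operand reads grow by a factor of 33/32 (about 3%). A quick symbolic check with sympy (symbol names are ours):

```python
import sympy

M, K, N = sympy.symbols("M K N", positive=True)
num_reads = M * K + K * N         # gemm operand elements read
num_scale_reads = num_reads / 32  # one e8m0 scale per block of 32 (the patch uses //)
print(sympy.simplify((num_reads + num_scale_reads) / num_reads))  # -> 33/32
```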
@@ -255,6 +271,8 @@ def get_individual_gemm_time_sympy(
     elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         # read in float8, output in bfloat16
         bytes_rw = num_reads * BYTES_PER_EL_FLOAT8 + num_writes * BYTES_PER_EL_BF16
+    elif dtype is torch.float4_e2m1fn_x2:
+        bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
     else:
         assert False, "unsupported"
     mem_gemm_time_s = (
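For the memory-bound term, the fp4 branch halves the operand read bytes relative to fp8 while the bf16 output write is unchanged. A small numeric sketch (square 8192 shapes chosen only for illustration):

```python
BYTES_PER_EL_FLOAT4 = 0.5
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2

M = K = N = 8192
num_reads = M * K + K * N  # operand elements read
num_writes = M * N         # output elements written (bf16)

bytes_fp8 = num_reads * BYTES_PER_EL_FLOAT8 + num_writes * BYTES_PER_EL_BF16
bytes_fp4 = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
print(bytes_fp8, bytes_fp4)  # fp4 reads half the operand bytes; output traffic is identical
```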