 BYTES_PER_EL_FLOAT4 = 0.5
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT32 = 4

 gpu_name_to_specs = {
     "NVIDIA H100": {
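These constants are bytes per element: e2m1 fp4 packs two elements per byte (0.5 bytes each), the fp8 formats take 1 byte, bf16 takes 2 bytes, and the new fp32 constant covers 4-byte scale factors. A minimal sketch of how they translate an element count into traffic bytes (the helper and shapes below are illustrative, not part of the file):

# hypothetical illustration of the bytes-per-element constants above
def tensor_bytes(numel: int, bytes_per_el: float) -> float:
    return numel * bytes_per_el

# e.g. a 4096 x 4096 activation:
#   bf16: 4096 * 4096 * 2   = 33,554,432 bytes (~32 MiB)
#   fp8:  4096 * 4096 * 1   = 16,777,216 bytes (~16 MiB)
#   fp4:  4096 * 4096 * 0.5 =  8,388,608 bytes (~8 MiB)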
@@ -228,7 +229,7 @@ def get_individual_gemm_time_sympy(
     K: sympy.Symbol,
     N: sympy.Symbol,
     dtype,
-    mx_recipe_name,
+    mx_recipe_name: Optional[str],
     gpu_name: Optional[str] = None,
 ) -> sympy.Symbol:
     # compute bound
@@ -241,27 +242,24 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         peak_tops = specs["fp4_peak_tops"]
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]

     # memory bound
     num_reads = M * K + K * N
     num_writes = M * N

     if mx_recipe_name is not None:
-        assert mx_recipe_name in (
-            "mxfp8_emulated",
-            "mxfp8_cublas",
-            "mxfp8_cublas_rceil",
-            "mxfp4_cutlass",
-        ), "unsupported"
+        assert mx_recipe_name.startswith(("mxfp8", "mxfp4", "nvfp4")), (
+            f"Unsupported recipe {mx_recipe_name}"
+        )
         assert dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.float4_e2m1fn_x2,
         ), "unsupported"
         # adjust reads for MX scaling
-        block_size = 32
+        block_size = 32 if mx_recipe_name.startswith("mx") else 16
         num_scale_reads = num_reads // block_size
         # note: e8m0 bytes per element is the same as for e4m3|e5m2
         num_reads = num_reads + num_scale_reads
@@ -274,7 +272,7 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     mem_gemm_time_s = (
         bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )
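For a rough feel of the memory-bound term above, a small numeric sketch; the bandwidth and efficiency values are placeholders, not the specs table from this file:

# illustrative numbers for the memory-bound branch above (values are assumptions)
M, K, N = 8192, 8192, 8192
num_reads = M * K + K * N             # both gemm operands
num_writes = M * N                    # output, assumed bf16
block_size = 32                       # mxfp8/mxfp4 recipes; nvfp4 recipes use 16
num_reads += num_reads // block_size  # one scale element per block
bytes_rw = num_reads * 1 + num_writes * 2   # fp8 operands in, bf16 output out
peak_mem_bw_bytes_sec = 3.35e12       # placeholder H100-class bandwidth
pct_achievable_mem_bw = 0.8           # placeholder efficiency
mem_gemm_time_s = bytes_rw / peak_mem_bw_bytes_sec / pct_achievable_mem_bw
# ~2.7e8 bytes / 2.68e12 B/s ≈ 1.0e-4 s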
@@ -375,28 +373,68 @@ def get_inference_tensor_memory_traffic_ovhd_s(
     dim0,
     dim1,
     tensor_role: str,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
     fuse_with_prev=False,
 ) -> List[Union[sympy.Symbol, float]]:
     """
     Inference version of `get_tensor_memory_traffic_ovhd_s`.
     The only thing happening here is we quantize the activation.
     """
-    assert float8_recipe_name == "rowwise", "unsupported"
     assert fuse_with_prev is False, "unsupported"
+    assert tensor_role == "input", "inference only quantizes input activations"

     # assumes input bf16, output f8
     numel = dim0 * dim1

     res_bytes = None

-    assert tensor_role == "input"
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> x_fp8
-    kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-    res_bytes = [
-        kernel_1_rw,
-    ]
+    allowed_recipes = {"tensorwise", "rowwise", "mxfp8*", "nvfp4*", "mxfp4*"}
+
+    match recipe_name:
+        case "tensorwise":
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+            # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
+            # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel
+            # kernel 3: read in bf16, write in float8
+            kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            res_bytes = [kernel_1_rw, kernel_3_rw]
+
+        case "rowwise":
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> x_fp8 (with per-row scaling)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            # add in the bytes for scale writes
+            kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+            res_bytes = [kernel_1_rw]
+
+        case name if name and name.startswith("mxfp8"):
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            # add in the bytes for scale writes in E8M0 format
+            kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // 32)
+            res_bytes = [kernel_1_rw]
+
+        case name if name and (name.startswith("mxfp4") or name.startswith("nvfp4")):
+            # For NVFP4, assume minimal overhead since it's primarily a compute format
+            # x_bf16 = ...
+            # kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
+            if name.startswith("nvfp4"):
+                kernel_1_rw += BYTES_PER_EL_FLOAT32  # single scale factor
+            # add in the bytes for scale writes in E4M3 | E8M0
+            block_size = 32 if name.startswith("mxfp4") else 16
+            kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // block_size)
+            res_bytes = [kernel_1_rw]
+
+        case _:
+            raise ValueError(
+                f"Unknown recipe name: {recipe_name}. "
+                f"Allowed recipes: {allowed_recipes}"
+            )

     # convert from bytes to seconds
     res_s = [
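To make the branch-level accounting above concrete, a worked example under an assumed 4096 x 4096 bf16 activation:

# worked example for the rowwise and mxfp8 branches above (shapes are assumptions)
dim0, dim1 = 4096, 4096
numel = dim0 * dim1                              # 16,777,216 elements

# rowwise: read bf16, write fp8, plus one fp32 scale per row
rowwise_rw = 2 * numel + 1 * numel + 4 * dim0    # 50,331,648 + 16,384 bytes

# mxfp8: read bf16, write fp8, plus one e8m0 scale byte per 32-element block
mxfp8_rw = 2 * numel + 1 * numel + 1 * dim0 * (dim1 // 32)   # + 524,288 scale bytes

# in both cases the scale traffic is a small fraction of the data traffic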
@@ -414,7 +452,7 @@ def get_inference_float8_mem_sympy(
     M,
     K,
     N,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
     gpu_name: Optional[str] = None,
 ):
     specs = get_specs(gpu_name)
@@ -425,7 +463,7 @@ def get_inference_float8_mem_sympy(
         M,
         K,
         tensor_role="input",
-        float8_recipe_name=float8_recipe_name,
+        recipe_name=recipe_name,
         fuse_with_prev=False,
     )
     res = sum([*fwd_fp8_input_mem])
@@ -437,11 +475,12 @@ def get_inference_gemm_time_sympy(
     K: sympy.Symbol,
     N: sympy.Symbol,
     dtype,
-    float8_recipe_name: Optional[str],
-    gpu_name: Optional[str],
+    recipe_name: Optional[str],
+    gpu_name: Optional[str] = None,
 ):
-    assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
     # note: this function is currently not super accurate for small shapes:
     # when M,K,N <= 1k,1k,1k it undercounts by around 2x
-    gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
+    gemm_output_time_s = get_individual_gemm_time_sympy(
+        M, K, N, dtype, recipe_name, gpu_name
+    )
     return gemm_output_time_s
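A minimal usage sketch of the renamed entry points, assuming they are imported from this roofline utils module; the recipe string and shapes are illustrative:

import sympy
import torch

M, K, N = sympy.symbols("M K N", positive=True)

# symbolic GEMM time and input-quantization overhead for an mxfp8 recipe
gemm_time_s = get_inference_gemm_time_sympy(
    M, K, N, torch.float8_e4m3fn, recipe_name="mxfp8_cublas"
)
ovhd_s = get_inference_float8_mem_sympy(M, K, N, recipe_name="mxfp8_cublas")

# evaluate at a concrete shape
total_s = (gemm_time_s + ovhd_s).subs({M: 8192, K: 8192, N: 8192})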