     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -134,12 +135,18 @@ def get_gemm_times(
     elif recipe_name == "mxfp8_cublas":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "mxfp4_cutlass":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "nvfp4":
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
 
     else:
         assert False, "unsupported"
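For readers unfamiliar with the new `to_blocked` calls above, the sketch below shows the step in isolation: per-block scales are built row-major (one scale per 32-element block along K) and then rearranged into the blocked ("swizzled") layout that block-scaled GEMM kernels consume. The shapes, dtypes, and block size come from the hunk above; the concrete output layout of `to_blocked` and the CUDA device requirement are assumptions, so treat this as a sketch rather than part of the PR.

```python
# Standalone sketch of the scale swizzling added above (not part of this PR).
# Assumes a CUDA device and that to_blocked rearranges row-major per-block scales
# into the blocked layout expected by block-scaled GEMMs.
import torch
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 4096, 4096, 4096
device = "cuda"

# one e8m0 scale per 32-element block along K, as in the mxfp8/mxfp4 branches
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)

# rearrange into the blocked layout before handing the scales to the gemm
scale_a_blocked = to_blocked(scale_a)
scale_b_blocked = to_blocked(scale_b)
print(scale_a.shape, "->", scale_a_blocked.shape)
```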
@@ -166,6 +173,9 @@ def run(
     recipe_name: str,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
     enable_fusion_modeling: bool = False,
@@ -174,7 +184,8 @@ def run(
     Args:
     * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
+    * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
     # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
@@ -187,9 +198,13 @@ def run(
187198 ["do_benchmarks" , do_benchmarks ],
188199 ["shape_gen_name" , shape_gen_name ],
189200 ["enable_fusion_modeling" , enable_fusion_modeling ],
201+ ["MKN" , f"{ M } { K } { N } " ],
190202 ]
191203 print (tabulate (config_table , headers = ["Parameter" , "Value" ], tablefmt = "simple" ))
192204
+    # save the user-specified M, K, N before the names are rebound to sympy symbols below
+    user_M, user_K, user_N = M, K, N
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
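A note on why the new `user_M, user_K, user_N` assignment is needed: immediately afterwards the names `M`, `K`, `N` are rebound to sympy symbols, so the user-supplied integers would otherwise be lost. A minimal illustration (the values are arbitrary):

```python
# Minimal illustration of the name shadowing handled above (values are arbitrary).
import sympy

M, K, N = 4096, 4096, 4096        # user-provided shapes
user_M, user_K, user_N = M, K, N  # saved before the names are rebound
M, K, N = sympy.symbols("M K N")  # M/K/N now refer to sympy symbols, not ints
print(user_M, M)                  # -> 4096 M
```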
@@ -245,7 +260,7 @@ def run(
     ]
     results = []
 
-    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None)
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
 
     for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
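The last hunk forwards the saved `user_M/user_K/user_N` into `get_name_to_shapes_iter`. The PR does not show that helper, but the docstring addition suggests a `custom` mode that yields the user-provided shape once; the sketch below is a hypothetical illustration of that behavior, not torchao's actual implementation.

```python
# Hypothetical sketch of a "custom" branch in a shape iterator; illustrative only,
# not taken from torchao's get_name_to_shapes_iter.
from typing import Iterator, Optional, Tuple


def custom_shapes_iter(
    M: Optional[int], K: Optional[int], N: Optional[int]
) -> Iterator[Tuple[str, Tuple[int, int, int]]]:
    # the benchmark loop expects (name, (M, K, N)) tuples
    assert None not in (M, K, N), "custom shapes require M, K and N"
    yield "custom", (M, K, N)


for name, (M_val, K_val, N_val) in custom_shapes_iter(4096, 4096, 4096):
    print(name, M_val, K_val, N_val)
```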