 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
+    PerTensor,
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
@@ -179,6 +180,11 @@ def run(
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
     enable_fusion_modeling: bool = False,
+    op_name: str = "linear",
+    D: Optional[int] = None,
+    H: Optional[int] = None,
+    W: Optional[int] = None,
+    kernel_size: Optional[int] = None,
 ):
     """
     Args:
@@ -189,7 +195,29 @@ def run(
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
     # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
+    # `op_name`: linear, conv2d or conv3d, decides which op to benchmark
+    # `D`, `H`, `W`: spatial dimensions for conv2d / conv3d
+    # `kernel_size`: kernel_size for conv2d / conv3d
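+
+    Example (illustrative only, with hypothetical values; assumes `outfile`,
+    `recipe_name` and `do_benchmarks` are existing arguments of `run`):
+        run("conv2d_roofline.csv", recipe_name="rowwise", do_benchmarks=True,
+            op_name="conv2d", H=56, W=56, kernel_size=3)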
192201 """
202+ _SUPPORTED_OPS = ["linear" , "conv2d" , "conv3d" ]
203+ assert op_name in _SUPPORTED_OPS , (
204+ f"Unsupported op: { op_name } , supported are: { _SUPPORTED_OPS } "
205+ )
206+ if op_name == "conv2d" :
207+ assert H is not None and W is not None , (
208+ "Expected D, H, W to be specified for conv2d"
209+ )
210+ assert kernel_size is not None , (
211+ "Expected kernel_size to be specified for conv2d"
212+ )
213+ elif op_name == "conv3d" :
214+ assert D is not None and H is not None and W is not None , (
215+ "Expected D, H, W to be specified for conv3d"
216+ )
217+ assert kernel_size is not None , (
218+ "Expected kernel_size to be specified for conv3d"
219+ )
220+
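+    # for conv2d/conv3d, the M/K/N shape values are reused as M = batch size,
+    # K = in_channels, N = out_channels in the model and input setup below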
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
         ["torch version", torch.__version__],
@@ -198,7 +226,10 @@ def run(
198226 ["do_benchmarks" , do_benchmarks ],
199227 ["shape_gen_name" , shape_gen_name ],
200228 ["enable_fusion_modeling" , enable_fusion_modeling ],
229+ ["op_name" , op_name ],
201230 ["MKN" , f"{ M } { K } { N } " ],
231+ ["DHW" , f"{ D } { H } { W } " ],
232+ ["kernel_size" , kernel_size ],
202233 ]
203234 print (tabulate (config_table , headers = ["Parameter" , "Value" ], tablefmt = "simple" ))
204235
@@ -207,33 +238,45 @@ def run(
 
     M, K, N = sympy.symbols("M K N")
 
-    fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
-        M,
-        K,
-        N,
-        recipe_name,
-        # TODO(future): also enable fusion modeling here
-    )
-    bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
-
-    if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
-        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-            M, K, N, torch.float4_e2m1fn_x2, recipe_name
+    if op_name == "linear":
+        fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
+            M,
+            K,
+            N,
+            recipe_name,
+            # TODO(future): also enable fusion modeling here
         )
-    else:
-        gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
-        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-            M, K, N, torch.float8_e4m3fn, gemm_recipe_name
+        bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.bfloat16, None
         )
-    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
-    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
-    print()
 
+        if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
+            fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+                M, K, N, torch.float4_e2m1fn_x2, recipe_name
+            )
+        else:
+            gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
+            fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+                M, K, N, torch.float8_e4m3fn, gemm_recipe_name
+            )
+        print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+        print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+        print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
+        print()
+    else:
+        # TODO: enable roofline analysis for conv
+        pass
+
+    # Note: roofline modeling for conv2d/conv3d is not added yet, so most of the
+    # roofline-related values below are left as None for conv2d/conv3d for now
     headers = [
         "fwd_M",
         "fwd_K",
         "fwd_N",
+        "D",
+        "H",
+        "W",
+        "kernel_size",
         # roofline - gemm time (fwd + bwd, 3 gemms)
         "r_bf16_gemm_s",
         "r_fp8_gemm_s",
@@ -258,6 +301,7 @@ def run(
258301 "rb_bf16_gemm_ratio" ,
259302 "rb_fp8_gemm_ratio" ,
260303 ]
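+    # column name prefixes: r_ = roofline estimate, b_ = measured benchmark,
+    # rb_ = roofline / benchmark ratio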
+
     results = []
 
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
@@ -266,54 +310,93 @@ def run(
         if n_limit is not None and idx >= n_limit:
             break
 
-        # use roofline model to estimate gemm time
-        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
-        r_bf16_gemm_time_s = float(
-            bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        r_fp8_gemm_time_s = float(
-            fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-
-        # if enabled, also measured observed gemm time
-        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
-        rb_bf16_gemm_ratio = -1
-        rb_fp8_gemm_ratio = -1
+        if op_name == "linear":
+            # use roofline model to estimate gemm time
+            # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+            r_bf16_gemm_time_s = float(
+                bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            )
+            r_fp8_gemm_time_s = float(
+                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            )
 
-        if do_benchmarks:
-            # TODO(future): make the bf16 gemm times exactly match the e2e
-            # benchmarks, there is a slight deviation, probably related to gemm
-            # operand memory formats/transpositions below not exactly matching
-            # what PyTorch core is doing for `torch.mm`
-            # input @ weight_t = output
-            bf16_g1, f8_g1 = get_gemm_times(
-                M_val,
-                K_val,
-                N_val,
-                True,
-                recipe_name,
+            # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+            r_fp8_ovhd_time_s = float(
+                fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             )
-            b_bf16_gemm_time_s = bf16_g1
-            b_fp8_gemm_time_s = f8_g1
-            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
-            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
-
-        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
-        r_fp8_ovhd_time_s = float(
-            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
+            r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+            r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
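+            # modeled speedup: bf16 gemm time vs. fp8 gemm time plus the float8
+            # overhead (the extra memory traffic modeled by get_inference_float8_mem_sympy)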
+
+            # if enabled, also measure observed gemm time
+            b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+            rb_bf16_gemm_ratio = -1
+            rb_fp8_gemm_ratio = -1
+
+            if do_benchmarks:
+                # TODO(future): make the bf16 gemm times exactly match the e2e
+                # benchmarks, there is a slight deviation, probably related to gemm
+                # operand memory formats/transpositions below not exactly matching
+                # what PyTorch core is doing for `torch.mm`
+                # input @ weight_t = output
+                bf16_g1, f8_g1 = get_gemm_times(
+                    M_val,
+                    K_val,
+                    N_val,
+                    True,
+                    recipe_name,
+                )
+                b_bf16_gemm_time_s = bf16_g1
+                b_fp8_gemm_time_s = f8_g1
+                rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+                rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
+
+        else:
+            # roofline analysis for conv2d/conv3d is not added yet
+            r_bf16_gemm_time_s = None
+            r_fp8_gemm_time_s = None
+
+            r_fp8_ovhd_time_s = None
+            r_fp8_gemm_and_ovhd_s = None
+            r_speedup = None
+
+            # real gemm benchmark time, also not added yet
+            # if enabled, also measure observed gemm time
+            b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+            # gemm roofline ratio achieved in real benchmark
+            rb_bf16_gemm_ratio = -1
+            rb_fp8_gemm_ratio = -1
 
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            if not enable_fusion_modeling:
-                m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+            if op_name == "conv2d":
+                m_orig = nn.Sequential(
+                    nn.Conv2d(K_val, N_val, kernel_size, bias=False)
+                ).to(memory_format=torch.channels_last)
+            elif op_name == "conv3d":
+                m_orig = nn.Sequential(
+                    nn.Conv3d(K_val, N_val, kernel_size, bias=False)
+                ).to(memory_format=torch.channels_last_3d)
             else:
-                m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
+                if not enable_fusion_modeling:
+                    m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+                else:
+                    m_orig = nn.Sequential(
+                        nn.ReLU(), nn.Linear(K_val, N_val, bias=False)
+                    )
             m_orig = m_orig.cuda().bfloat16()
-            x = torch.randn(
-                M_val, K_val, dtype=torch.bfloat16, device="cuda"
-            ).requires_grad_()
+            if op_name == "conv2d":
+                x = torch.randn(
+                    M_val, K_val, H, W, dtype=torch.bfloat16, device="cuda"
+                ).to(memory_format=torch.channels_last)
+            elif op_name == "conv3d":
+                x = torch.randn(
+                    M_val, K_val, D, H, W, dtype=torch.bfloat16, device="cuda"
+                ).to(memory_format=torch.channels_last_3d)
+            else:
+                x = torch.randn(
+                    M_val, K_val, dtype=torch.bfloat16, device="cuda"
+                ).requires_grad_()
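+            # note: both the conv modules and their inputs use channels_last /
+            # channels_last_3d; the assumption is that the quantized conv path
+            # expects (or performs best with) an NHWC / NDHWC memory layout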
 
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
@@ -327,7 +410,11 @@ def run(
             # get the float8 dynamic scaling gpu kernel time
             torch._dynamo.reset()
 
-            if recipe_name == "rowwise":
+            if recipe_name == "tensorwise":
+                config = Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(),
+                )
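+                # note: PerTensor() keeps a single float8 scale for the whole tensor,
+                # while PerRow() below keeps one scale per row (finer-grained scaling)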
+            elif recipe_name == "rowwise":
                 config = Float8DynamicActivationFloat8WeightConfig(
                     granularity=PerRow(),
                     # for now, use TORCH. In the future might be interesting
@@ -355,7 +442,14 @@ def run(
                 assert False, "unsupported"
 
             m_fp8_dyn = copy.deepcopy(m_orig)
-            quantize_(m_fp8_dyn, config)
+            if op_name == "linear":
+                quantize_(m_fp8_dyn, config)
+            elif op_name == "conv2d":
+                _is_conv2d = lambda m, fqn: isinstance(m, torch.nn.Conv2d)
+                quantize_(m_fp8_dyn, config, filter_fn=_is_conv2d)
+            else:
+                _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+                quantize_(m_fp8_dyn, config, filter_fn=_is_conv3d)
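+            # quantize_ applies `config` only to modules accepted by `filter_fn`;
+            # its default filter targets nn.Linear, so conv modules are selected explicitly here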
 
             m_fp8_dyn = torch.compile(m_fp8_dyn)
 
@@ -364,20 +458,22 @@ def run(
                 fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json"
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename)
 
-        r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
-
         results.append(
             [
                 M_val,
                 K_val,
                 N_val,
+                D,
+                H,
+                W,
+                kernel_size,
                 # roofline - gemm
                 r_bf16_gemm_time_s,
                 r_fp8_gemm_time_s,
                 # roofline - fp8 overhead
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
-                r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
+                r_fp8_gemm_and_ovhd_s,
                 r_speedup,
                 # benchmarks - gemm
                 b_bf16_gemm_time_s,