@@ -125,7 +125,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, block_k: int) -> Op

     config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
     if os.path.exists(config_file_path):
-        with open(config_file_path) as f:
+        with open(config_file_path, "r", encoding="utf-8") as f:
             logger.info(
                 "Using configuration from %s for W8A8 Block FP8 kernel.",
                 config_file_path,
@@ -332,56 +332,56 @@ def is_enough_memory(x_val):
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
-        x_names=['B', 'M', 'N', 'K'],
+        x_names=["B", "M", "N", "K"],
         # different possible values for `x_name`
         x_vals=X_VALS,
-        line_arg='provider',
+        line_arg="provider",
         # argument name whose value corresponds to a different line in the plot
-        # possible values for `line_arg``
-        line_vals=['triton'],
+        line_vals=["triton"],
         # label name for the lines
-        line_names=['Triton'],
+        line_names=["Triton"],
         # line styles
-        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
-        plot_name='matmul-performance',
+        ylabel=["GB/s", "TFlops"],  # label name for the y-axis
+        plot_name="matmul-performance",
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
 def benchmark(B, M, N, K, provider):
-    block_size = [[128, 128]]
+    assert provider == "triton"
+
+    block_size = [128, 128]

     torch.manual_seed(0)
     factor_for_scale = 1e-2
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min

-    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+    A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
     A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

-    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+    B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
     B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k

-    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
-    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
+    As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
+    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale

     quantiles = [0.5, 0.0, 1.0]

-    c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
-    triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, c, As, Bs, block_size)
-    torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, c, As, Bs, block_size)
-    rtol = 1e-2 if c.dtype == torch.bfloat16 else 1e-3
-    benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
+    triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+    torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
+    rtol = 1e-3
+    benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg="triton to torch")
     _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)

     tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
-    gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
+    gbps = lambda ms: B * ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)

     return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


-if __name__ == '__main__':
+if __name__ == "__main__":
     benchmark.run(show_plots=False, print_data=True)
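Note on the `gbps` change (commentary, not part of the commit): the old formula counted 2 bytes per input element and 4 bytes per output element, i.e. 16-bit inputs and an fp32 output; the updated formula counts 1 byte per element for the fp8 inputs and 2 bytes per element for the output, consistent with dropping the preallocated fp32 `c`. A minimal sketch of that byte-count model, assuming a 16-bit output dtype (e.g. bfloat16), with illustrative names:

# Illustrative only -- not part of the benchmark file. Assumes fp8 inputs
# (1 byte/element) and a 16-bit output (2 bytes/element), which is what the
# coefficients in the updated `gbps` lambda imply.
def bytes_moved(B, M, N, K):
    input_bytes = M * K + K * N   # A_fp8 and B_fp8: 1 byte per element
    output_bytes = 2.0 * M * N    # 16-bit result: 2 bytes per element
    return B * (input_bytes + output_bytes)

# GB/s then follows as bytes_moved(B, M, N, K) * 1e-9 / (ms * 1e-3),
# matching the updated lambda term for term.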