Skip to content

Commit 8061054

Browse files
committed
fix bugs
1 parent 8ca57d8 commit 8061054

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ jobs:
276276
source ../../scripts/capture-hw-details.sh
277277
python ../../scripts/build_report.py $REPORTS/prefix-sums.csv $REPORTS/prefix_sums-triton-report.csv --benchmark prefix_sums --compiler triton --param_cols "N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
278278
279+
- name: Run SGLang FP8 GEMM benchmark
280+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'block_fp8_matmul.py') }}
281+
run: |
282+
cd benchmarks/triton_kernels_benchmark/sglang
283+
python block_fp8_matmul.py --reports $REPORTS
284+
source ../../scripts/capture-hw-details.sh
285+
python ../../scripts/build_report.py $REPORTS/block_fp8_matmul.csv $REPORTS/block_fp8_matmul-triton-report.csv --benchmark block_fp8_matmul --compiler triton --param_cols "B,M,N,K" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
286+
279287
- name: Run micro benchmark
280288
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'micro_benchmarks') }}
281289
run: |

benchmarks/triton_kernels_benchmark/sglang/block_fp8_matmul.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, block_k: int) -> Op
125125

126126
config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
127127
if os.path.exists(config_file_path):
128-
with open(config_file_path) as f:
128+
with open(config_file_path, "r", encoding="utf-8") as f:
129129
logger.info(
130130
"Using configuration from %s for W8A8 Block FP8 kernel.",
131131
config_file_path,
@@ -332,56 +332,56 @@ def is_enough_memory(x_val):
332332
@benchmark_suit.perf_report(
333333
benchmark_suit.Benchmark(
334334
# argument names to use as an x-axis for the plot
335-
x_names=['B', 'M', 'N', 'K'],
335+
x_names=["B", "M", "N", "K"],
336336
# different possible values for `x_name`
337337
x_vals=X_VALS,
338-
line_arg='provider',
338+
line_arg="provider",
339339
# argument name whose value corresponds to a different line in the plot
340-
# possible values for `line_arg``
341-
line_vals=['triton'],
340+
line_vals=["triton"],
342341
# label name for the lines
343-
line_names=['Triton'],
342+
line_names=["Triton"],
344343
# line styles
345-
ylabel=['GB/s', 'TFlops'], # label name for the y-axis
346-
plot_name='matmul-performance',
344+
ylabel=["GB/s", "TFlops"], # label name for the y-axis
345+
plot_name="matmul-performance",
347346
# name for the plot. Used also as a file name for saving the plot.
348347
args={},
349348
))
350349
def benchmark(B, M, N, K, provider):
351-
block_size = [[128, 128]]
350+
assert provider == "triton"
351+
352+
block_size = [128, 128]
352353

353354
torch.manual_seed(0)
354355
factor_for_scale = 1e-2
355356
fp8_info = torch.finfo(torch.float8_e4m3fn)
356357
fp8_max, fp8_min = fp8_info.max, fp8_info.min
357358

358-
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
359+
A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
359360
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
360361

361-
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
362+
B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="xpu") - 0.5) * 2 * fp8_max
362363
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
363364

364365
block_n, block_k = block_size[0], block_size[1]
365366
n_tiles = (N + block_n - 1) // block_n
366367
k_tiles = (K + block_k - 1) // block_k
367368

368-
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
369-
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
369+
As = torch.rand(M, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
370+
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="xpu") * factor_for_scale
370371

371372
quantiles = [0.5, 0.0, 1.0]
372373

373-
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
374-
triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, c, As, Bs, block_size)
375-
torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, c, As, Bs, block_size)
376-
rtol = 1e-2 if c.dtype == torch.bfloat16 else 1e-3
377-
benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
374+
triton_fn = lambda: w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
375+
torch_fn = lambda: native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size)
376+
rtol = 1e-3
377+
benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg="triton to torch")
378378
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
379379

380380
tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3)
381-
gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
381+
gbps = lambda ms: B * ((M * K + K * N) + 2.0 * (M * N)) * (1e-9) / (ms * 1e-3)
382382

383383
return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
384384

385385

386-
if __name__ == '__main__':
386+
if __name__ == "__main__":
387387
benchmark.run(show_plots=False, print_data=True)

0 commit comments

Comments
 (0)