Commit ed88733

Port and enable tutorial 09-persistent-matmul (#3833)
Enable the tutorial `09-persistent-matmul` and run it with both a tensor of pointers (`matmul_kernel_persistent`) and with tensor descriptors (`matmul_kernel_descriptor_persistent`). The latter kernel is internally translated into blocked pointers.

Note: rebase after PRs #3817 and #3832 land.

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
1 parent a645533 commit ed88733
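
For orientation, here is a minimal sketch (not part of the commit) of exercising the two code paths the message names. It assumes the script is loaded from the repository root and that a float16-capable CUDA or XPU device is active; shapes and tolerances are illustrative only, and the argument order mirrors the tutorial's own `validate()` further down.

import importlib.util
import torch

# The file name starts with "09-", so load it via importlib instead of `import`.
spec = importlib.util.spec_from_file_location("pm", "python/tutorials/09-persistent-matmul.py")
pm = importlib.util.module_from_spec(spec)
spec.loader.exec_module(pm)

M = N = K = 512
a = torch.randn((M, K), device=pm.DEVICE, dtype=torch.float16)
b = torch.randn((K, N), device=pm.DEVICE, dtype=torch.float16).T.contiguous()

c_ptrs = pm.matmul_persistent(a, b.T)  # tensor-of-pointers kernel
if pm.HAS_TENSOR_DESC:
    c_desc = pm.matmul_descriptor_persistent(a, b, warp_specialize=False)  # tensor-descriptor kernel
    torch.testing.assert_close(c_ptrs, c_desc, atol=1.0, rtol=0.0)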

File tree

5 files changed (+37 -19 lines changed):

  python/tutorials/09-persistent-matmul.py
  scripts/skiplist/lts/tutorials.txt
  scripts/test-triton.sh
  test/TritonIntelGPU/combine.mlir
  third_party/intel/lib/Analysis/AxisInfo.cpp

python/tutorials/09-persistent-matmul.py

Lines changed: 35 additions & 17 deletions
@@ -31,6 +31,8 @@
 
 from typing import Optional
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 if torch.cuda.is_available():
     from triton._C.libtriton import nvidia
     cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
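
The new DEVICE constant is what lets the allocations further down drop the hard-coded "cuda" string. A small sketch of what it evaluates to, assuming a PyTorch build with a working CUDA or XPU backend behind the active Triton driver:

import torch
import triton

# Resolves to the torch.device of the backend Triton is currently driving,
# e.g. cuda:0 on NVIDIA GPUs or xpu:0 on Intel GPUs.
DEVICE = triton.runtime.driver.active.get_active_torch_device()
x = torch.empty(16, device=DEVICE, dtype=torch.float16)
print(DEVICE, x.device)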
@@ -43,6 +45,10 @@ def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
 
 
+def is_xpu():
+    return triton.runtime.driver.active.get_current_target().backend == "xpu"
+
+
 def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
@@ -51,6 +57,14 @@ def supports_ws():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
 
 
+def num_sms():
+    if is_cuda():
+        return torch.cuda.get_device_properties("cuda").multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties("xpu").gpu_eu_count
+    return 148
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
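
num_sms() is the per-backend upper bound on how many persistent programs get launched: streaming multiprocessors on CUDA, EU count on XPU, and 148 as a fallback. A minimal sketch, with an assumed problem size and stand-in constant, of how the tutorial's persistent kernels clamp the launch grid with this bound so each resident program loops over several output tiles:

import triton

M, N = 8192, 8192
NUM_SMS = 148  # stand-in; the real value comes from num_sms() above

def grid(META):
    # At most one program per SM/EU, and never more programs than output tiles.
    tiles = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
    return (min(NUM_SMS, tiles), )

print(grid({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}))  # -> (148,) for this example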
@@ -66,7 +80,7 @@ def _matmul_launch_metadata(grid, kernel, args):
 
 
 HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
-HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
+HAS_TENSOR_DESC = (is_xpu() or supports_tma()) and hasattr(tl, "make_tensor_descriptor")
 HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC
 
 
@@ -390,7 +404,8 @@ def matmul_persistent(a, b):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.dtype == b.dtype, "Incompatible dtypes"
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = num_sms()
+
     M, K = a.shape
     K, N = b.shape
     dtype = a.dtype
@@ -504,7 +519,7 @@ def matmul_tma_persistent(a, b, warp_specialize: bool):
     desc_helper.init_tma_descriptor("b")
     desc_helper.init_tma_descriptor("c")
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = num_sms()
 
     def grid(META):
         nonlocal desc_helper
@@ -649,11 +664,11 @@ def matmul_descriptor_persistent(a, b, warp_specialize: bool):
     dtype = a.dtype
 
     c = torch.empty((M, N), device=a.device, dtype=dtype)
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = num_sms()
 
     # TMA descriptors require a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
-        return torch.empty(size, device="cuda", dtype=torch.int8)
+        return torch.empty(size, device=DEVICE, dtype=torch.int8)
 
     triton.set_allocator(alloc_fn)
 
@@ -706,17 +721,19 @@ def bench_fn(label, reps, warmup_reps, fn, *args):
     print(f"Benchmarking {label}: ...", end="")
     for _ in range(warmup_reps):
         fn(*args)
-    with proton_context():
-        for _ in range(reps):
-            fn(*args)
+    #FIXME: Enable for XPU once proton support works.
+    if is_cuda():
+        with proton_context():
+            for _ in range(reps):
+                fn(*args)
     print(f"\rBenchmarking {label}: done")
 
 
 def bench(K, dtype, reps=10000, warmup_reps=10000):
     M = 8192
     N = 8192
-    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
-    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype)
 
     b = b.T.contiguous()
 
@@ -750,8 +767,8 @@ def run_test(expect, fn, a, b, label, enabled=True):
 
 def validate(M, N, K, dtype):
     print(f"{M=}, {N=}, {K=}, verification naive vs: ")
-    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
-    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
 
     naive_result = matmul(a, b.T).to(torch.float16)
@@ -806,10 +823,11 @@ def show_profile(precision, profile_name):
 
     validate(32, 32, 32, dtype)
     validate(8192, 8192, args.K_range[0], dtype)
-
-    proton.start("matmul", hook="triton")
-    proton.deactivate()
+    if is_cuda():
+        proton.start("matmul", hook="triton")
+        proton.deactivate()
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
         bench(K, dtype)
-    proton.finalize()
-    show_profile(args.prec, "matmul")
+    if is_cuda():
+        proton.finalize()
+        show_profile(args.prec, "matmul")

scripts/skiplist/lts/tutorials.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 03-matrix-multiplication
 06-fused-attention
 08-grouped-gemm
+09-persistent-matmul
 10-experimental-block-pointer
 10i-experimental-block-pointer

scripts/test-triton.sh

Lines changed: 1 addition & 0 deletions
@@ -252,6 +252,7 @@ run_tutorial_tests() {
    run_tutorial_test "06-fused-attention"
    run_tutorial_test "07-extern-functions"
    run_tutorial_test "08-grouped-gemm"
+   TRITON_TEST_REPORTS=false run_tutorial_test "09-persistent-matmul"
    run_tutorial_test "10-experimental-block-pointer"
    run_tutorial_test "10i-experimental-block-pointer"
 

test/TritonIntelGPU/combine.mlir

Lines changed: 0 additions & 1 deletion
@@ -2410,7 +2410,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
 // COM: Reproducer for issue #3817 (to ensure that the compiler doesn't crash).
 
 // CHECK: #[[$BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 0 additions & 1 deletion
@@ -1284,7 +1284,6 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto maskOrder = linAttr.getOrder();
   if (maskOrder[0] >= axisInfo->getRank())
     return 1;
-
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
                                         << alignment);
