 
 from typing import Optional
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 if torch.cuda.is_available():
     from triton._C.libtriton import nvidia
     cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
@@ -43,6 +45,10 @@ def is_cuda():
     return triton.runtime.driver.active.get_current_target().backend == "cuda"
 
 
+def is_xpu():
+    return triton.runtime.driver.active.get_current_target().backend == "xpu"
+
+
 def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
@@ -51,6 +57,14 @@ def supports_ws():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 10
 
 
+def num_sms():
+    if is_cuda():
+        return torch.cuda.get_device_properties("cuda").multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties("xpu").gpu_eu_count
+    return 148
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
@@ -66,7 +80,7 @@ def _matmul_launch_metadata(grid, kernel, args):
 
 
 HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
-HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
+HAS_TENSOR_DESC = (is_xpu() or supports_tma()) and hasattr(tl, "make_tensor_descriptor")
 HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC
 
 
@@ -390,7 +404,8 @@ def matmul_persistent(a, b):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.dtype == b.dtype, "Incompatible dtypes"
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = num_sms()
+
     M, K = a.shape
     K, N = b.shape
     dtype = a.dtype
@@ -504,7 +519,7 @@ def matmul_tma_persistent(a, b, warp_specialize: bool):
     desc_helper.init_tma_descriptor("b")
     desc_helper.init_tma_descriptor("c")
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = num_sms()
 
     def grid(META):
         nonlocal desc_helper
@@ -649,11 +664,11 @@ def matmul_descriptor_persistent(a, b, warp_specialize: bool):
     dtype = a.dtype
 
     c = torch.empty((M, N), device=a.device, dtype=dtype)
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    NUM_SMS = num_sms()
 
     # TMA descriptors require a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
-        return torch.empty(size, device="cuda", dtype=torch.int8)
+        return torch.empty(size, device=DEVICE, dtype=torch.int8)
 
     triton.set_allocator(alloc_fn)
 
@@ -706,17 +721,19 @@ def bench_fn(label, reps, warmup_reps, fn, *args):
     print(f"Benchmarking {label}: ...", end="")
     for _ in range(warmup_reps):
         fn(*args)
-    with proton_context():
-        for _ in range(reps):
-            fn(*args)
+    # FIXME: Enable for XPU once proton support works.
+    if is_cuda():
+        with proton_context():
+            for _ in range(reps):
+                fn(*args)
     print(f"\rBenchmarking {label}: done")
 
 
 def bench(K, dtype, reps=10000, warmup_reps=10000):
     M = 8192
     N = 8192
-    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
-    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype)
 
     b = b.T.contiguous()
 
@@ -750,8 +767,8 @@ def run_test(expect, fn, a, b, label, enabled=True):
 
 
 def validate(M, N, K, dtype):
     print(f"{M=}, {N=}, {K=}, verification naive vs: ")
-    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
-    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
+    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16).to(dtype)
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.float16).to(dtype)
     b = b.T.contiguous()
 
     naive_result = matmul(a, b.T).to(torch.float16)
@@ -806,10 +823,11 @@ def show_profile(precision, profile_name):
 
     validate(32, 32, 32, dtype)
     validate(8192, 8192, args.K_range[0], dtype)
-
-    proton.start("matmul", hook="triton")
-    proton.deactivate()
+    if is_cuda():
+        proton.start("matmul", hook="triton")
+        proton.deactivate()
     for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
         bench(K, dtype)
-    proton.finalize()
-    show_profile(args.prec, "matmul")
+    if is_cuda():
+        proton.finalize()
+        show_profile(args.prec, "matmul")
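
A minimal standalone sketch of the device-dispatch helpers this change introduces, assuming a PyTorch/Triton install with either the CUDA or the XPU backend active; the helper names and the hard-coded fallback of 148 come straight from the diff above, while the __main__ usage below is illustrative only:

import torch
import triton

def is_cuda():
    # Backend string reported by the active Triton driver.
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

def is_xpu():
    return triton.runtime.driver.active.get_current_target().backend == "xpu"

def num_sms():
    # CUDA exposes a streaming-multiprocessor count; XPU exposes an execution-unit count.
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    if is_xpu():
        return torch.xpu.get_device_properties("xpu").gpu_eu_count
    return 148  # fallback taken from the diff when neither backend is detected

if __name__ == "__main__":
    DEVICE = triton.runtime.driver.active.get_active_torch_device()
    print(f"device={DEVICE}, parallel units={num_sms()}")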