@@ -1,4 +1,4 @@
-from triton.backends.compiler import BaseBackend, Language
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, intel
 from triton.backends.intel.driver import compile_module_from_src
 from triton.backends.intel.track import track
@@ -15,6 +15,11 @@
 import subprocess
 from pathlib import Path
 
+try:  # XPUBackend allows metaclass injection
+    from .meta import XPUBackendMeta
+except ImportError:
+    XPUBackendMeta = type(BaseBackend)
+
 
 @dataclass
 class XPUOptions:
@@ -63,40 +68,42 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
 
-def min_dot_size(device_props: dict):
-    # (M, N, K)
-    # M: repeatCount. 1,2,4,8
-    # N: executionSize. 16 for PVC, 8 for ATS
-    # K: systolicDepth x opsPerChan. systolicDepth must be 8
-    repeat_count = 1
-    sdepth = 8
-    exec_size = min(device_props["sub_group_sizes"])
-
-    def get_ops_per_channel(lhs_type, rhs_type):
-        l_bitwidth = lhs_type.scalar.primitive_bitwidth
-        r_bitwidth = rhs_type.scalar.primitive_bitwidth
-        max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
-        return min(8, max_ops_per_chan)
-
-    return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
-
-
-class XPUBackend(BaseBackend):
+class XPUBackend(BaseBackend, metaclass=XPUBackendMeta):
+    arch_to_impl = {}  # Architecture id to backend implementation class mapping
+    binary_ext = "spv"
+    target_arch = "spir64"
+    device_props: dict = {}
     instrumentation = None
 
     @staticmethod
-    def supports_target(target: tuple):
+    def supports_target(target: GPUTarget):
         return target.backend == 'xpu'
 
-    def __init__(self, target: tuple) -> None:
-        super().__init__(target)
+    def __new__(cls, target: GPUTarget):
         if not isinstance(target.arch, dict):
             raise TypeError("target.arch is not a dict")
-        dirname = os.path.dirname(os.path.realpath(__file__))
-        mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
-        self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
+        if cls is not XPUBackend:
+            return super().__new__(cls)
+        arch = target.arch.get("architecture", 0)
+        if (impl := cls.arch_to_impl.get(arch, None)) is None:
+            # Try to find an arch-specific implementation in the .arch.<name> submodule.
+            if not (dev_arch := knobs.intel.device_arch):
+                dirname = os.path.dirname(os.path.realpath(__file__))
+                parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(),
+                                                 name="arch_utils")
+                dev_arch = parser.parse_device_arch(target.arch.get('architecture', 0))
+            mod_name = f"{__package__}.arch.{dev_arch}"
+            try:
+                impl = __import__(mod_name, fromlist=["XPUBackendImpl"]).XPUBackendImpl
+            except ImportError:
+                impl = type(f"{mod_name}.XPUBackendImpl", (cls, ), {})
+            impl.device_arch = dev_arch
+            cls.arch_to_impl[arch] = impl
+        return super().__new__(impl)
+
+    def __init__(self, target: GPUTarget) -> None:
+        super().__init__(target)
         self.properties = self.parse_target(target.arch)
-        self.binary_ext = "spv"
 
     def get_target_name(self, options) -> str:
         return f"xpu:{self.device_arch}"
@@ -120,21 +127,43 @@ def parse_target(self, tgt_prop) -> dict:
         dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
         dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
 
+        if self.device_arch in self.device_props:
+            dev_prop.update(self.device_props[self.device_arch])
+            return dev_prop
+
         return dev_prop
 
     def parse_options(self, opts) -> Any:
-        args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
+        args = {k: v for k, v in opts.items() if k in XPUOptions.__dataclass_fields__}
         args["allow_fp8e4nv"] = True
         return XPUOptions(**args)
 
     def pack_metadata(self, metadata):
         return metadata
 
+    @staticmethod
+    def min_dot_size(device_props: dict):
+        # (M, N, K)
+        # M: repeatCount. 1,2,4,8
+        # N: executionSize. 16 for PVC, 8 for ATS
+        # K: systolicDepth x opsPerChan. systolicDepth must be 8
+        repeat_count = 1
+        sdepth = 8
+        exec_size = min(device_props["sub_group_sizes"])
+
+        def get_ops_per_channel(lhs_type, rhs_type):
+            l_bitwidth = lhs_type.scalar.primitive_bitwidth
+            r_bitwidth = rhs_type.scalar.primitive_bitwidth
+            max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
+            return min(8, max_ops_per_chan)
+
+        return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
+
     def get_codegen_implementation(self, options):
         from triton.language.extra.intel import convert_custom_float8
         codegen_fns = {}
         codegen_fns["convert_custom_types"] = convert_custom_float8
-        codegen_fns["min_dot_size"] = min_dot_size(self.properties)
+        codegen_fns["min_dot_size"] = self.min_dot_size(self.properties)
         return codegen_fns
 
     def get_module_map(self) -> Dict[str, ModuleType]:
@@ -143,8 +172,8 @@ def get_module_map(self) -> Dict[str, ModuleType]:
 
     def load_dialects(self, ctx):
         intel.load_dialects(ctx)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.load_dialects(ctx)
+        if self.instrumentation:
+            self.instrumentation.load_dialects(ctx)
 
     @staticmethod
     def validate_options(opt, properties):
@@ -158,20 +187,15 @@ def validate_options(opt, properties):
                 f"num_warps={opt.num_warps} is unsupported for the target (limit is {properties['max_num_sub_groups']})"
             )
 
-    @staticmethod
-    def annotate_module(mod, properties, opt, target_arch):
+    @classmethod
+    def annotate_module(cls, module_opts, properties, opt):
         # Annotate module with information required by subsequent transformations.
-        pm = ir.pass_manager(mod.context)
-        pm.enable_debug()
-        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
         module_opts.min_sg_size = min(properties["sub_group_sizes"])
         module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"]
         module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"]
         module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"]
         module_opts.threads_per_warp = opt.warp_size
-        module_opts.target_arch = target_arch
-        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
-        pm.run(mod, 'annotate_module')
+        module_opts.target_arch = cls.target_arch
 
     @staticmethod
     def get_split_barrier_scope(opt):
@@ -182,9 +206,9 @@ def get_split_barrier_scope(opt):
             split_barriers_scope = intel.SplitBarrierScope.Subgroup
         return split_barriers_scope
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttir(mod, metadata, opt):
+    def make_ttir(cls, mod, metadata, opt):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
@@ -204,21 +228,26 @@ def make_ttir(mod, metadata, opt):
         pm.run(mod, 'make_ttir')
         return mod
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttgir(mod, metadata, opt, properties):
+    def make_ttgir(cls, mod, metadata, opt, properties):
         cluster_info = intel.ClusterInfo()
         if opt.cluster_dims is not None:
             cluster_info.clusterDimX = opt.cluster_dims[0]
             cluster_info.clusterDimY = opt.cluster_dims[1]
             cluster_info.clusterDimZ = opt.cluster_dims[2]
 
         # Annotate module with information required by subsequent transformations.
-        XPUBackend.annotate_module(mod, properties, opt, "spir64")
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
+        cls.annotate_module(module_opts, properties, opt)
+        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
+        pm.run(mod, 'annotate_module')
 
         # Overwrite the warp_size option with the module annotation.
         opt.warp_size = intel.get_threads_per_warp(mod)
-        XPUBackend.validate_options(opt, properties)
+        cls.validate_options(opt, properties)
 
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
@@ -278,9 +307,15 @@ def gluon_to_ttgir(self, src, metadata, options):
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
-    @staticmethod
+    @classmethod
+    def optimize_llvm_mod(cls, llvm_mod, options):
+        intel.set_spv_target_triple(llvm_mod)
+        with track("optimize_module") as tr:
+            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
+
+    @classmethod
     @track
-    def make_llir(src, metadata, options):
+    def make_llir(cls, src, metadata, options):
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
@@ -292,8 +327,8 @@ def make_llir(src, metadata, options):
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         intel.passes.ttgpuir.add_to_llvmir(pm)
         intel.passes.ttgpuir.add_gen_to_llvm(pm)
         passes.common.add_canonicalizer(pm)
@@ -307,8 +342,8 @@ def make_llir(src, metadata, options):
         if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
             passes.llvmir.add_di_scope(pm)
 
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
         pm.run(mod, 'make_llir')
 
         if knobs.compilation.dump_ir_extract_di_local_variables:
@@ -333,15 +368,12 @@ def make_llir(src, metadata, options):
         llvm.init_targets()
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
-        intel.set_spv_target_triple(llvm_mod)
         intel.set_fast_math(llvm_mod)
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
 
-        with track("optimize_module") as tr:
-            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
-
+        cls.optimize_llvm_mod(llvm_mod, options)
         intel.post_process_llir(llvm_mod)
 
         # Get some metadata
@@ -359,9 +391,9 @@ def make_llir(src, metadata, options):
         del context
         return ret
 
-    @staticmethod
+    @classmethod
     @track
-    def make_spv(src, metadata, options, device_arch):
+    def make_spv(cls, src, metadata, options):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
         metadata.setdefault("build_flags", "")
@@ -380,8 +412,9 @@ def make_spv(src, metadata, options, device_arch):
             metadata["build_flags"] += " -cl-opt-disable"
         return spirv
 
-    @staticmethod
-    def make_zebin(src, metadata, options, device_arch):
+    @classmethod
+    @track
+    def make_zebin(cls, src, metadata, options):
         metadata["binary_ext"] = "zebin"
 
         shader_dump_opt = ""
@@ -398,8 +431,8 @@ def make_zebin(src, metadata, options, device_arch):
         fbin = fsrc.name + '.o'
 
         ocloc_cmd = [
-            'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
-            metadata["build_flags"] + shader_dump_opt
+            'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch,
+            '-options', metadata["build_flags"] + shader_dump_opt
         ]
 
         try:
@@ -437,9 +470,9 @@ def add_stages(self, stages, options, language):
         elif language == Language.GLUON:
             stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
-        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
+        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
         if options.generate_native_code:
-            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
+            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options)
         if knobs.runtime.add_stages_inspection_hook is not None:
             knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
 
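A quick worked example for the min_dot_size helper that moved onto the class above, using only the arithmetic visible in that hunk: on a device whose smallest sub-group size is 16 (the PVC case named in the comment), fp16 x fp16 operands give min(8, 32 / 16) = 2 ops per channel, so the returned callable yields a minimum dot shape of (M, N, K) = (1, 16, 8 * 2) = (1, 16, 16); int8 operands give 32 / 8 = 4 ops per channel and therefore (1, 16, 32).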
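For readers following the new __new__ dispatch above: it caches one implementation class per architecture id in arch_to_impl, importing {__package__}.arch.<device_arch> and taking its XPUBackendImpl attribute, or synthesizing a bare subclass when no such module exists. Below is a minimal sketch of what such an arch submodule could look like; the "pvc" architecture name, the property override value, and the assumption that the class in this diff lives in compiler.py are all illustrative and not taken from the commit itself.

# arch/pvc.py -- hypothetical arch-specific implementation module (sketch)
from ..compiler import XPUBackend  # assumption: this diff's file is compiler.py


class XPUBackendImpl(XPUBackend):
    # Per-architecture property overrides; parse_target() merges the entry
    # keyed by self.device_arch into the reported device properties.
    device_props = {"pvc": {"has_subgroup_2d_block_io": True}}

    @classmethod
    def optimize_llvm_mod(cls, llvm_mod, options):
        # Arch-specific hook point introduced by this commit; the sketch simply
        # reuses the default SPIR-V-targeted O3 pipeline defined on XPUBackend.
        super().optimize_llvm_mod(llvm_mod, options)

With such a module in place, __new__ would return an instance of this class whenever the parsed device architecture resolves to "pvc" (device_arch is assigned by __new__ itself), keeping per-architecture behavior out of the generic compile pipeline.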