Commit d3ce2a9

XPUBackend refactoring to facilitate arch-specific implementations
- Allow metaclass injection: use XPUBackendMeta from the .meta submodule as the metaclass, if it exists.
- Try to find an arch-specific implementation in the .arch.<name> submodule.
- Create the list of passes in separate methods so that subclasses can modify it.
1 parent 0794e64 commit d3ce2a9

File tree

1 file changed: +96, -63 lines

third_party/intel/backend/compiler.py

Lines changed: 96 additions & 63 deletions
@@ -1,4 +1,4 @@
-from triton.backends.compiler import BaseBackend, Language
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, intel
 from triton.backends.intel.driver import compile_module_from_src
 from triton.backends.intel.track import track
@@ -15,6 +15,11 @@
 import subprocess
 from pathlib import Path
 
+try:  # XPUBackend allows metaclass injection
+    from .meta import XPUBackendMeta
+except ImportError:
+    XPUBackendMeta = type(BaseBackend)
+
 
 @dataclass
 class XPUOptions:
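Note: the try/except above makes the metaclass optional. A minimal sketch of what a .meta submodule could provide; this module is not part of the commit, and the registry hook below is purely illustrative:

# Hypothetical triton/backends/intel/meta.py (illustrative sketch only).
from triton.backends.compiler import BaseBackend


class XPUBackendMeta(type(BaseBackend)):
    # Derive from BaseBackend's metaclass, matching the fallback
    # XPUBackendMeta = type(BaseBackend) used when this module is absent.
    registry = []

    def __new__(mcs, name, bases, namespace):
        cls = super().__new__(mcs, name, bases, namespace)
        mcs.registry.append(cls)  # e.g. let tooling enumerate arch subclasses
        return cls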
@@ -63,40 +68,42 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
 
-def min_dot_size(device_props: dict):
-    # (M, N, K)
-    # M: repeatCount. 1,2,4,8
-    # N: executionSize. 16 for PVC, 8 for ATS
-    # K: systolicDepth x opsPerChan. systolicDepth must be 8
-    repeat_count = 1
-    sdepth = 8
-    exec_size = min(device_props["sub_group_sizes"])
-
-    def get_ops_per_channel(lhs_type, rhs_type):
-        l_bitwidth = lhs_type.scalar.primitive_bitwidth
-        r_bitwidth = rhs_type.scalar.primitive_bitwidth
-        max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
-        return min(8, max_ops_per_chan)
-
-    return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
-
-
-class XPUBackend(BaseBackend):
+class XPUBackend(BaseBackend, metaclass=XPUBackendMeta):
+    arch_to_impl = {}  # Architecture id to backend implementation class mapping
+    binary_ext = "spv"
+    target_arch = "spir64"
+    device_props: dict = {}
     instrumentation = None
 
     @staticmethod
-    def supports_target(target: tuple):
+    def supports_target(target: GPUTarget):
         return target.backend == 'xpu'
 
-    def __init__(self, target: tuple) -> None:
-        super().__init__(target)
+    def __new__(cls, target: GPUTarget):
         if not isinstance(target.arch, dict):
             raise TypeError("target.arch is not a dict")
-        dirname = os.path.dirname(os.path.realpath(__file__))
-        mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
-        self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
+        if cls is not XPUBackend:
+            return super().__new__(cls)
+        arch = target.arch.get("architecture", 0)
+        if (impl := cls.arch_to_impl.get(arch, None)) is None:
+            # Try to find an arch-specific implementation in the .arch.<name> submodule.
+            if not (dev_arch := knobs.intel.device_arch):
+                dirname = os.path.dirname(os.path.realpath(__file__))
+                parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(),
+                                                 name="arch_utils")
+                dev_arch = parser.parse_device_arch(target.arch.get('architecture', 0))
+            mod_name = f"{__package__}.arch.{dev_arch}"
+            try:
+                impl = __import__(mod_name, fromlist=["XPUBackendImpl"]).XPUBackendImpl
+            except ImportError:
+                impl = type(f"{mod_name}.XPUBackendImpl", (cls, ), {})
+            impl.device_arch = dev_arch
+            cls.arch_to_impl[arch] = impl
+        return super().__new__(impl)
+
+    def __init__(self, target: GPUTarget) -> None:
+        super().__init__(target)
         self.properties = self.parse_target(target.arch)
-        self.binary_ext = "spv"
 
     def get_target_name(self, options) -> str:
         return f"xpu:{self.device_arch}"
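With the new __new__, instantiating XPUBackend returns an arch-specific subclass looked up in the .arch.<name> submodule, or a synthesized one when no such module exists. A sketch of what such a module might look like; the arch name "bmg" and the property override are assumptions, not part of this commit:

# Hypothetical triton/backends/intel/arch/bmg.py (sketch under assumptions).
from triton.backends.intel.compiler import XPUBackend


class XPUBackendImpl(XPUBackend):
    # Per-arch property overrides: parse_target() (below) merges the entry
    # keyed by self.device_arch into the detected device properties.
    device_props = {"bmg": {"has_subgroup_2d_block_io": True}}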
@@ -120,21 +127,43 @@ def parse_target(self, tgt_prop) -> dict:
         dev_prop['has_subgroup_2d_block_io'] = tgt_prop.get('has_subgroup_2d_block_io', False)
         dev_prop['has_bfloat16_conversions'] = tgt_prop.get('has_bfloat16_conversions', True)
 
+        if self.device_arch in self.device_props:
+            dev_prop.update(self.device_props[self.device_arch])
+            return dev_prop
+
         return dev_prop
 
     def parse_options(self, opts) -> Any:
-        args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
+        args = {k: v for k, v in opts.items() if k in XPUOptions.__dataclass_fields__}
         args["allow_fp8e4nv"] = True
         return XPUOptions(**args)
 
     def pack_metadata(self, metadata):
         return metadata
 
+    @staticmethod
+    def min_dot_size(device_props: dict):
+        # (M, N, K)
+        # M: repeatCount. 1,2,4,8
+        # N: executionSize. 16 for PVC, 8 for ATS
+        # K: systolicDepth x opsPerChan. systolicDepth must be 8
+        repeat_count = 1
+        sdepth = 8
+        exec_size = min(device_props["sub_group_sizes"])
+
+        def get_ops_per_channel(lhs_type, rhs_type):
+            l_bitwidth = lhs_type.scalar.primitive_bitwidth
+            r_bitwidth = rhs_type.scalar.primitive_bitwidth
+            max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
+            return min(8, max_ops_per_chan)
+
+        return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
+
     def get_codegen_implementation(self, options):
         from triton.language.extra.intel import convert_custom_float8
         codegen_fns = {}
         codegen_fns["convert_custom_types"] = convert_custom_float8
-        codegen_fns["min_dot_size"] = min_dot_size(self.properties)
+        codegen_fns["min_dot_size"] = self.min_dot_size(self.properties)
         return codegen_fns
 
     def get_module_map(self) -> Dict[str, ModuleType]:
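min_dot_size is now a static method, so arch subclasses can replace it. The K component follows directly from the comments in the hunk above; a quick check of the arithmetic (standalone sketch, not code from the commit):

# K = systolicDepth * opsPerChan, with opsPerChan capped at 8.
def min_k(bitwidth: int, sdepth: int = 8) -> int:
    return sdepth * min(8, 32 // bitwidth)


assert min_k(32) == 8   # fp32: 1 op per channel
assert min_k(16) == 16  # fp16/bf16: 2 ops per channel
assert min_k(8) == 32   # fp8/int8: 4 ops per channel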
@@ -143,8 +172,8 @@ def get_module_map(self) -> Dict[str, ModuleType]:
 
     def load_dialects(self, ctx):
         intel.load_dialects(ctx)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.load_dialects(ctx)
+        if self.instrumentation:
+            self.instrumentation.load_dialects(ctx)
 
     @staticmethod
     def validate_options(opt, properties):
@@ -158,20 +187,15 @@ def validate_options(opt, properties):
                 f"num_warps={opt.num_warps} is unsupported for the target (limit is {properties['max_num_sub_groups']})"
             )
 
-    @staticmethod
-    def annotate_module(mod, properties, opt, target_arch):
+    @classmethod
+    def annotate_module(cls, module_opts, properties, opt):
         # Annotate module with information required by subsequent transformations.
-        pm = ir.pass_manager(mod.context)
-        pm.enable_debug()
-        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
         module_opts.min_sg_size = min(properties["sub_group_sizes"])
         module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"]
         module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"]
         module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"]
         module_opts.threads_per_warp = opt.warp_size
-        module_opts.target_arch = target_arch
-        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
-        pm.run(mod, 'annotate_module')
+        module_opts.target_arch = cls.target_arch
 
     @staticmethod
     def get_split_barrier_scope(opt):
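annotate_module now only fills in the annotation options; the pass-manager plumbing moves into make_ttgir (further down). An arch-specific subclass can therefore adjust a single annotation without re-implementing the pipeline. A hypothetical override, with an invented tweak for illustration:

# Sketch only: XPUBackendImpl and the forced subgroup size are assumptions.
class XPUBackendImpl(XPUBackend):

    @classmethod
    def annotate_module(cls, module_opts, properties, opt):
        super().annotate_module(module_opts, properties, opt)
        module_opts.min_sg_size = 16  # hypothetical arch-specific constraint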
@@ -182,9 +206,9 @@ def get_split_barrier_scope(opt):
             split_barriers_scope = intel.SplitBarrierScope.Subgroup
         return split_barriers_scope
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttir(mod, metadata, opt):
+    def make_ttir(cls, mod, metadata, opt):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
@@ -204,21 +228,26 @@ def make_ttir(mod, metadata, opt):
         pm.run(mod, 'make_ttir')
         return mod
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttgir(mod, metadata, opt, properties):
+    def make_ttgir(cls, mod, metadata, opt, properties):
         cluster_info = intel.ClusterInfo()
         if opt.cluster_dims is not None:
             cluster_info.clusterDimX = opt.cluster_dims[0]
             cluster_info.clusterDimY = opt.cluster_dims[1]
             cluster_info.clusterDimZ = opt.cluster_dims[2]
 
         # Annotate module with information required by subsequent transformations.
-        XPUBackend.annotate_module(mod, properties, opt, "spir64")
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
+        cls.annotate_module(module_opts, properties, opt)
+        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
+        pm.run(mod, 'annotate_module')
 
         # Overwrite the warp_size option with the module annotation.
         opt.warp_size = intel.get_threads_per_warp(mod)
-        XPUBackend.validate_options(opt, properties)
+        cls.validate_options(opt, properties)
 
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
@@ -278,9 +307,15 @@ def gluon_to_ttgir(self, src, metadata, options):
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
-    @staticmethod
+    @classmethod
+    def optimize_llvm_mod(cls, llvm_mod, options):
+        intel.set_spv_target_triple(llvm_mod)
+        with track("optimize_module") as tr:
+            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
+
+    @classmethod
     @track
-    def make_llir(src, metadata, options):
+    def make_llir(cls, src, metadata, options):
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
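Extracting optimize_llvm_mod gives subclasses a single hook over LLVM-level optimization. A hedged sketch of an override; the subclass is hypothetical, and it assumes llvm.OPTIMIZE_O2 is bound alongside the OPTIMIZE_O3 used above:

from triton._C.libtriton import intel, llvm
from triton.backends.intel.compiler import XPUBackend
from triton.backends.intel.track import track


class XPUBackendImpl(XPUBackend):

    @classmethod
    def optimize_llvm_mod(cls, llvm_mod, options):
        intel.set_spv_target_triple(llvm_mod)
        with track("optimize_module") as tr:
            # Assumption: OPTIMIZE_O2 exists in the llvm bindings.
            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O2, tr.callback("passes"))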
@@ -292,8 +327,8 @@ def make_llir(src, metadata, options):
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         intel.passes.ttgpuir.add_to_llvmir(pm)
         intel.passes.ttgpuir.add_gen_to_llvm(pm)
         passes.common.add_canonicalizer(pm)
@@ -307,8 +342,8 @@ def make_llir(src, metadata, options):
         if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
             passes.llvmir.add_di_scope(pm)
 
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
         pm.run(mod, 'make_llir')
 
         if knobs.compilation.dump_ir_extract_di_local_variables:
@@ -333,15 +368,12 @@ def make_llir(src, metadata, options):
         llvm.init_targets()
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
-        intel.set_spv_target_triple(llvm_mod)
         intel.set_fast_math(llvm_mod)
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
 
-        with track("optimize_module") as tr:
-            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
-
+        cls.optimize_llvm_mod(llvm_mod, options)
         intel.post_process_llir(llvm_mod)
 
         # Get some metadata
@@ -359,9 +391,9 @@ def make_llir(src, metadata, options):
         del context
         return ret
 
-    @staticmethod
+    @classmethod
     @track
-    def make_spv(src, metadata, options, device_arch):
+    def make_spv(cls, src, metadata, options):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
         metadata.setdefault("build_flags", "")
metadata.setdefault("build_flags", "")
@@ -380,8 +412,9 @@ def make_spv(src, metadata, options, device_arch):
380412
metadata["build_flags"] += " -cl-opt-disable"
381413
return spirv
382414

383-
@staticmethod
384-
def make_zebin(src, metadata, options, device_arch):
415+
@classmethod
416+
@track
417+
def make_zebin(cls, src, metadata, options):
385418
metadata["binary_ext"] = "zebin"
386419

387420
shader_dump_opt = ""
@@ -398,8 +431,8 @@ def make_zebin(src, metadata, options, device_arch):
         fbin = fsrc.name + '.o'
 
         ocloc_cmd = [
-            'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
-            metadata["build_flags"] + shader_dump_opt
+            'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch,
+            '-options', metadata["build_flags"] + shader_dump_opt
         ]
 
         try:
@@ -437,9 +470,9 @@ def add_stages(self, stages, options, language):
         elif language == Language.GLUON:
             stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
-        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
+        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
         if options.generate_native_code:
-            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
+            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options)
         if knobs.runtime.add_stages_inspection_hook is not None:
             knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
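Taken together, make_spv and make_zebin no longer take device_arch as an argument: it is a class attribute of the implementation class selected in __new__. A sketch of the resulting dispatch, assuming an installed Intel backend; the architecture id 0 is a placeholder:

from triton.backends.compiler import GPUTarget
from triton.backends.intel.compiler import XPUBackend

target = GPUTarget("xpu", {"architecture": 0}, warp_size=32)
backend = XPUBackend(target)   # __new__ resolves and caches the impl class
print(type(backend).__name__)  # the arch-specific XPUBackendImpl subclass
print(backend.device_arch)     # parsed from the architecture id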