
Commit e29cd0a

feat: put the tt func in a separate module and use symbol ref

1 parent 013180c

File tree

2 files changed: +87 -78 lines

src/Compiler.jl

Lines changed: 37 additions & 38 deletions
@@ -1317,42 +1317,42 @@ function triton_optimization_passes()
             "convert-nvvm-to-llvm",
             # common passes
             "canonicalize",
-            # # ttir passes
-            # "triton-combine",
-            # "triton-reorder-broadcast",
-            # "triton-rewrite-tensor-pointer",
-            # "triton-rewrite-tensor-descriptor-to-pointer",
-            # "triton-loop-unroll",
-            # "triton-licm",
-            # "triton-loop-aware-cse",
-            # # TODO: should num-warps and num-ctas be set for each kernel?
-            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
-            # # ttgir passes
-            # "tritongpu-coalesce",
-            # "tritongpu-optimize-thread-locality",
-            # "tritongpu-hoist-tmem-alloc",
-            # "tritongpu-assign-latencies",
-            # "tritongpu-pipeline",
-            # "tritongpu-schedule-loops",
-            # "tritongpu-automatic-warp-specialization",
-            # "tritongpu-prefetch",
-            # "tritongpu-accelerate-matmul",
-            # "tritongpu-reorder-instructions",
-            # "tritongpu-F32DotTC",
-            # "tritongpu-optimize-dot-operands",
-            # "tritongpu-remove-layout-conversions",
-            # "tritongpu-reduce-data-duplication",
-            # "tritongpu-hoist-tmem-alloc",
-            # "tritongpu-fuse-nested-loops",
-            # "tritongpu-rewrite-partition-dependencies",
-            # "tritongpu-partition-loops",
-            # "tritongpu-combine-tensor-select-and-if",
-            # # ttgir to llvm passes
-            # "tritongpu-allocate-warp-groups",
-            # "allocate-shared-memory",
-            # "tritongpu-global-scratch-memory-allocation",
-            # "tritongpu-optimize-accumulator-init",
-            # "tritongpu-coalesce-async-copy",
+            # ttir passes
+            "triton-combine",
+            "triton-reorder-broadcast",
+            "triton-rewrite-tensor-pointer",
+            "triton-rewrite-tensor-descriptor-to-pointer",
+            "triton-loop-unroll",
+            "triton-licm",
+            "triton-loop-aware-cse",
+            # TODO: should num-warps and num-ctas be set for each kernel?
+            "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # ttgir passes
+            "tritongpu-coalesce",
+            "tritongpu-optimize-thread-locality",
+            "tritongpu-hoist-tmem-alloc",
+            "tritongpu-assign-latencies",
+            "tritongpu-pipeline",
+            "tritongpu-schedule-loops",
+            "tritongpu-automatic-warp-specialization",
+            "tritongpu-prefetch",
+            "tritongpu-accelerate-matmul",
+            "tritongpu-reorder-instructions",
+            "tritongpu-F32DotTC",
+            "tritongpu-optimize-dot-operands",
+            "tritongpu-remove-layout-conversions",
+            "tritongpu-reduce-data-duplication",
+            "tritongpu-hoist-tmem-alloc",
+            "tritongpu-fuse-nested-loops",
+            "tritongpu-rewrite-partition-dependencies",
+            "tritongpu-partition-loops",
+            "tritongpu-combine-tensor-select-and-if",
+            # ttgir to llvm passes
+            "tritongpu-allocate-warp-groups",
+            "allocate-shared-memory",
+            "tritongpu-global-scratch-memory-allocation",
+            "tritongpu-optimize-accumulator-init",
+            "tritongpu-coalesce-async-copy",
         ],
         ",",
     )
@@ -2311,8 +2311,7 @@ function compile_mlir!(
         end
     end
 
-    # XXX: re-enable this pass
-    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
 
     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname
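
For context on the first hunk: `triton_optimization_passes` builds one comma-joined MLIR pass-pipeline string (the trailing `],` / `","` / `)` close a `join(...)` call), and the re-enabled `run_pass_pipeline!` call in the second hunk shows how such a string is consumed. A minimal sketch of the pattern, with a deliberately abbreviated pass list (the full list is in the hunk above):

    # Sketch only: abbreviated pass list for illustration.
    function triton_optimization_passes_sketch()
        passes = [
            "convert-nvvm-to-llvm",
            "canonicalize",
            "triton-combine",
        ]
        # One comma-separated pipeline string, ready to hand to a pass manager.
        return join(passes, ",")
    end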

src/Ops.jl

Lines changed: 50 additions & 40 deletions
@@ -1750,54 +1750,60 @@ function _extract_function(
     func_name::String="main",
     func_op_kind::String="func.func",
     nested_module::Bool=false,
+    location::MLIR.IR.Location=MLIR.IR.Location(),
 )
     module_suffix = string(hash(code); base=16)
-    name_to_call = _new_function_name(func_name, module_suffix)
+    name_to_call = func_name * "_call_" * module_suffix
+    mod_name = func_name * "_module_" * module_suffix
+    symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
 
-    current_module = MLIR.IR.mmodule()
     if nested_module
-        new_module = MLIR.IR.Module()
-        push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true))
-        current_module = new_module
-    end
-    top_level_block = MLIR.IR.body(current_module)
+        region = MLIR.IR.Region()
+        push!(region, MLIR.IR.Block())
+        moduleop = MLIR.Dialects.builtin.module_(;
+            location, bodyRegion=region, sym_name=mod_name
+        )
+        MLIR.IR.rmfromparent!(moduleop)
+        push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module
 
-    symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
-    fn = MLIR.IR.lookup(
-        MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
-    )
+        top_level_block = MLIR.IR.Block(
+            MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false
+        )
+        fn = nothing
+    else
+        current_module = MLIR.IR.mmodule()
+        moduleop = MLIR.IR.Operation(current_module)
+        top_level_block = MLIR.IR.body(current_module)
+        fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call)
+    end
 
     if isnothing(fn)
         new_mod = parse(MLIR.IR.Module, code)
         new_mod_op = MLIR.IR.Operation(new_mod)
         body = MLIR.IR.body(new_mod)
 
         operations = collect(MLIR.IR.OperationIterator(body))
-        for op in operations
-            if MLIR.IR.name(op) == func_op_kind
-                fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
-                if fn_name == func_name
-                    fn = op
-                end
+        idx = Base.findfirst(op -> MLIR.IR.name(op) == func_op_kind, operations)
+        @assert idx !== nothing
+        op = operations[idx]
 
-                res = MLIR.IR.LogicalResult(
-                    MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
-                        fn_name, name_to_call, new_mod_op
-                    ),
-                )
-                @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
-
-                # Set function private
-                MLIR.IR.attr!(
-                    op,
-                    MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
-                    MLIR.IR.Attribute("private"),
-                )
-
-                # Change function name
-                MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
-            end
-        end
+        fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
+        fn_name == func_name && (fn = op)
+
+        res = MLIR.IR.LogicalResult(
+            MLIR.API.mlirSymbolTableReplaceAllSymbolUses(fn_name, name_to_call, new_mod_op)
+        )
+        @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
+
+        # Set function private
+        MLIR.IR.attr!(
+            op,
+            MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
+            MLIR.IR.Attribute("private"),
+        )
+
+        # Change function name
+        MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
 
         for op in operations
             MLIR.IR.rmfromparent!(op)
@@ -1809,7 +1815,7 @@ function _extract_function(
         error("hlo_call: could not find function $func_name in the provided module")
     end
 
-    return fn, name_to_call
+    return fn, name_to_call, mod_name
 end
 
 function triton_call(
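
Roughly, the `nested_module=true` path now materializes a named `builtin.module` inside the parent module and parses the extracted function into it, which is why `_extract_function` grows a third return value. A hypothetical call and the IR shape it should produce (names follow the `*_module_*` / `*_call_*` scheme above; `<hash>` stands for the `module_suffix` computed from `hash(code)`):

    # Hypothetical usage sketch; `mlir_code` holds the Triton IR text.
    fn, name_to_call, mod_name = _extract_function(
        mlir_code; func_name="myfunc", func_op_kind="tt.func", nested_module=true
    )
    # The parent module is now expected to contain, roughly:
    #   module @myfunc_module_<hash> {
    #     tt.func private @myfunc_call_<hash>(...) { ... }
    #   }
    # so callers need both `mod_name` and `name_to_call` to form the qualified
    # reference @myfunc_module_<hash>::@myfunc_call_<hash>.
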
@@ -1823,8 +1829,8 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call = _extract_function(
-        mlir_code; func_name, func_op_kind="tt.func", nested_module=true
+    _, name_to_call, mod_name = _extract_function(
+        mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location
     )
 
     enzymexla.triton_call(
@@ -1833,7 +1839,9 @@
         grid_z.mlir_data,
         shmem.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
-        fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+        fn=MLIR.IR.SymbolRefAttribute(
+            mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
+        ),
         result_0=MLIR.IR.Type[],
         location,
     )
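
The switch from `FlatSymbolRefAttribute` to a nested `SymbolRefAttribute` is the "symbol ref" half of the commit: a flat reference prints as `@name` and resolves in the nearest enclosing symbol table, while a nested reference should print as `@outer::@inner` and can reach into the new per-kernel module. A small sketch of the two forms (placeholder names, assuming an ambient MLIR context):

    # Flat: resolves `my_func` in the enclosing symbol table only.
    flat = MLIR.IR.FlatSymbolRefAttribute("my_func")            # @my_func

    # Nested: root symbol plus a nested path, as used for `fn=` above.
    nested = MLIR.IR.SymbolRefAttribute(
        "my_module", MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute("my_func")]
    )                                                           # @my_module::@my_func
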
@@ -1871,7 +1879,9 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
+    fn, name_to_call, _ = _extract_function(
+        code; func_name, func_op_kind="func.func", location
+    )
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
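
For the `func.func` path, only the new `location` kwarg is threaded through; behavior is otherwise unchanged. A hypothetical end-to-end use of `hlo_call` (made-up StableHLO module and inputs, default `location`):

    code = """
    module {
      func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
        %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
        return %0 : tensor<3xf32>
      }
    }
    """
    x = Reactant.to_rarray(Float32[1.0, 2.0, 3.0])
    # @main is cloned in as a private @main_call_<hash> copy and invoked via func.call.
    Reactant.@jit Reactant.Ops.hlo_call(code, x, x)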
