Commit 7f0afd8

feat: put the tt func in a separate module and use symbol ref

1 parent: 5b37e06

3 files changed: +88 −79 lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
 
 NSYNC_SHA256 = ""
 
-ENZYMEXLA_COMMIT = "b59185c7586783a17d9486e682307ae89c713964"
+ENZYMEXLA_COMMIT = "52ae936cae8f7050adc26c4ed5e755200497dc86"
 
 ENZYMEXLA_SHA256 = ""
 
src/Compiler.jl

Lines changed: 37 additions & 38 deletions
@@ -1310,42 +1310,42 @@ function triton_optimization_passes()
             "convert-nvvm-to-llvm",
             # common passes
             "canonicalize",
-            # # ttir passes
-            # "triton-combine",
-            # "triton-reorder-broadcast",
-            # "triton-rewrite-tensor-pointer",
-            # "triton-rewrite-tensor-descriptor-to-pointer",
-            # "triton-loop-unroll",
-            # "triton-licm",
-            # "triton-loop-aware-cse",
-            # # TODO: should num-warps and num-ctas be set for each kernel?
-            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
-            # # ttgir passes
-            # "tritongpu-coalesce",
-            # "tritongpu-optimize-thread-locality",
-            # "tritongpu-hoist-tmem-alloc",
-            # "tritongpu-assign-latencies",
-            # "tritongpu-pipeline",
-            # "tritongpu-schedule-loops",
-            # "tritongpu-automatic-warp-specialization",
-            # "tritongpu-prefetch",
-            # "tritongpu-accelerate-matmul",
-            # "tritongpu-reorder-instructions",
-            # "tritongpu-F32DotTC",
-            # "tritongpu-optimize-dot-operands",
-            # "tritongpu-remove-layout-conversions",
-            # "tritongpu-reduce-data-duplication",
-            # "tritongpu-hoist-tmem-alloc",
-            # "tritongpu-fuse-nested-loops",
-            # "tritongpu-rewrite-partition-dependencies",
-            # "tritongpu-partition-loops",
-            # "tritongpu-combine-tensor-select-and-if",
-            # # ttgir to llvm passes
-            # "tritongpu-allocate-warp-groups",
-            # "allocate-shared-memory",
-            # "tritongpu-global-scratch-memory-allocation",
-            # "tritongpu-optimize-accumulator-init",
-            # "tritongpu-coalesce-async-copy",
+            # ttir passes
+            "triton-combine",
+            "triton-reorder-broadcast",
+            "triton-rewrite-tensor-pointer",
+            "triton-rewrite-tensor-descriptor-to-pointer",
+            "triton-loop-unroll",
+            "triton-licm",
+            "triton-loop-aware-cse",
+            # TODO: should num-warps and num-ctas be set for each kernel?
+            "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # ttgir passes
+            "tritongpu-coalesce",
+            "tritongpu-optimize-thread-locality",
+            "tritongpu-hoist-tmem-alloc",
+            "tritongpu-assign-latencies",
+            "tritongpu-pipeline",
+            "tritongpu-schedule-loops",
+            "tritongpu-automatic-warp-specialization",
+            "tritongpu-prefetch",
+            "tritongpu-accelerate-matmul",
+            "tritongpu-reorder-instructions",
+            "tritongpu-F32DotTC",
+            "tritongpu-optimize-dot-operands",
+            "tritongpu-remove-layout-conversions",
+            "tritongpu-reduce-data-duplication",
+            "tritongpu-hoist-tmem-alloc",
+            "tritongpu-fuse-nested-loops",
+            "tritongpu-rewrite-partition-dependencies",
+            "tritongpu-partition-loops",
+            "tritongpu-combine-tensor-select-and-if",
+            # ttgir to llvm passes
+            "tritongpu-allocate-warp-groups",
+            "allocate-shared-memory",
+            "tritongpu-global-scratch-memory-allocation",
+            "tritongpu-optimize-accumulator-init",
+            "tritongpu-coalesce-async-copy",
         ],
         ",",
     )
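For reference, a minimal sketch (with hypothetical values, not taken from this commit) of how the interpolated convert-triton-to-tritongpu entry above expands at runtime:

# Hypothetical stand-ins for the Compiler.jl globals; the real values come from the detected GPU.
cubinChip = Ref("sm_80")
cuWarpSize = Ref(32)

# "sm_80"[4:end] == "80", so the option string interpolates to:
pass = "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}"
# -> "convert-triton-to-tritongpu{target=cuda:80 num-warps=1 threads-per-warp=32 num-ctas=1}"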
@@ -2303,8 +2303,7 @@ function compile_mlir!(
         end
     end
 
-    # XXX: re-enable this pass
-    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
 
     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname

src/Ops.jl

Lines changed: 50 additions & 40 deletions
@@ -1705,54 +1705,60 @@ function _extract_function(
     func_name::String="main",
     func_op_kind::String="func.func",
     nested_module::Bool=false,
+    location::MLIR.IR.Location=MLIR.IR.Location(),
 )
     module_suffix = string(hash(code); base=16)
-    name_to_call = _new_function_name(func_name, module_suffix)
+    name_to_call = func_name * "_call_" * module_suffix
+    mod_name = func_name * "_module_" * module_suffix
+    symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
 
-    current_module = MLIR.IR.mmodule()
     if nested_module
-        new_module = MLIR.IR.Module()
-        push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true))
-        current_module = new_module
-    end
-    top_level_block = MLIR.IR.body(current_module)
+        region = MLIR.IR.Region()
+        push!(region, MLIR.IR.Block())
+        moduleop = MLIR.Dialects.builtin.module_(;
+            location, bodyRegion=region, sym_name=mod_name
+        )
+        MLIR.IR.rmfromparent!(moduleop)
+        push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module
 
-    symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
-    fn = MLIR.IR.lookup(
-        MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
-    )
+        top_level_block = MLIR.IR.Block(
+            MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false
+        )
+        fn = nothing
+    else
+        current_module = MLIR.IR.mmodule()
+        moduleop = MLIR.IR.Operation(current_module)
+        top_level_block = MLIR.IR.body(current_module)
+        fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call)
+    end
 
     if isnothing(fn)
         new_mod = parse(MLIR.IR.Module, code)
         new_mod_op = MLIR.IR.Operation(new_mod)
         body = MLIR.IR.body(new_mod)
 
         operations = collect(MLIR.IR.OperationIterator(body))
-        for op in operations
-            if MLIR.IR.name(op) == func_op_kind
-                fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
-                if fn_name == func_name
-                    fn = op
-                end
+        idx = Base.findfirst(op -> MLIR.IR.name(op) == func_op_kind, operations)
+        @assert idx !== nothing
+        op = operations[idx]
 
-                res = MLIR.IR.LogicalResult(
-                    MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
-                        fn_name, name_to_call, new_mod_op
-                    ),
-                )
-                @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
-
-                # Set function private
-                MLIR.IR.attr!(
-                    op,
-                    MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
-                    MLIR.IR.Attribute("private"),
-                )
-
-                # Change function name
-                MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
-            end
-        end
+        fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
+        fn_name == func_name && (fn = op)
+
+        res = MLIR.IR.LogicalResult(
+            MLIR.API.mlirSymbolTableReplaceAllSymbolUses(fn_name, name_to_call, new_mod_op)
+        )
+        @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
+
+        # Set function private
+        MLIR.IR.attr!(
+            op,
+            MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
+            MLIR.IR.Attribute("private"),
+        )
+
+        # Change function name
+        MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
 
         for op in operations
            MLIR.IR.rmfromparent!(op)
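A rough sketch of what this hunk produces (hypothetical names, not from this commit): with nested_module=true, the parsed tt.func is renamed, marked private, and moved into its own symbol-named builtin.module, which is then pushed into the parent module.

# Hypothetical inputs for illustration only.
func_name = "add_kernel"
module_suffix = string(hash("...kernel source..."); base=16)  # value depends on the code string
name_to_call = func_name * "_call_" * module_suffix            # "add_kernel_call_<suffix>"
mod_name = func_name * "_module_" * module_suffix              # "add_kernel_module_<suffix>"

# Resulting layout inside the traced module (schematic MLIR):
#   module {                                     # parent module (MLIR.IR.mmodule())
#     module @add_kernel_module_<suffix> {
#       tt.func private @add_kernel_call_<suffix>(...) { ... }
#     }
#     ...
#   }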
@@ -1764,7 +1770,7 @@ function _extract_function(
         error("hlo_call: could not find function $func_name in the provided module")
     end
 
-    return fn, name_to_call
+    return fn, name_to_call, mod_name
 end
 
 function triton_call(
@@ -1778,8 +1784,8 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call = _extract_function(
-        mlir_code; func_name, func_op_kind="tt.func", nested_module=true
+    _, name_to_call, mod_name = _extract_function(
+        mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location
     )
 
     enzymexla.triton_call(
@@ -1788,7 +1794,9 @@ function triton_call(
         grid_z.mlir_data,
         shmem.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
-        fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+        fn=MLIR.IR.SymbolRefAttribute(
+            mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
+        ),
         result_0=MLIR.IR.Type[],
         location,
     )
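A minimal sketch (hypothetical symbol names; constructing attributes requires an active MLIR context, e.g. during tracing) of the nested symbol reference the call now emits. A SymbolRefAttribute carrying one nested FlatSymbolRefAttribute renders as @<module sym_name>::@<function sym_name>, so enzymexla.triton_call resolves to the tt.func inside the nested module rather than to a top-level symbol:

mod_name = "add_kernel_module_<suffix>"    # hypothetical, as returned by _extract_function
name_to_call = "add_kernel_call_<suffix>"  # hypothetical
fn_ref = MLIR.IR.SymbolRefAttribute(
    mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
)
# fn_ref prints roughly as: @add_kernel_module_<suffix>::@add_kernel_call_<suffix>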
@@ -1826,7 +1834,9 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
+    fn, name_to_call, _ = _extract_function(
+        code; func_name, func_op_kind="func.func", location
+    )
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
