Commit 4876110

feat: triton tracing works now finally

1 parent 7bee86b commit 4876110
5 files changed: +111 -50 lines changed

deps/ReactantExtra/WORKSPACE
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
 
 NSYNC_SHA256 = ""
 
-ENZYMEXLA_COMMIT = "9867ac059bb2f312a1a6d559d2b41d8ba333a589"
+ENZYMEXLA_COMMIT = "0d94adbc3a182ea6dbdc9d4103022beb7f1d20b9"
 
 ENZYMEXLA_SHA256 = ""

docs/src/.vitepress/config.mts
Lines changed: 2 additions & 0 deletions

@@ -131,6 +131,7 @@ export default defineConfig({
             { text: "SparseTensor", link: "/api/dialects/sparsetensor" },
             { text: "StableHLO", link: "/api/dialects/stablehlo" },
             { text: "Triton", link: "/api/dialects/triton" },
+            { text: "TritonExt", link: "/api/dialects/tritonext" },
             { text: "TPU", link: "/api/dialects/tpu" },
             { text: "VHLO", link: "/api/dialects/vhlo" },
           ],
@@ -221,6 +222,7 @@ export default defineConfig({
             { text: "SparseTensor", link: "/api/dialects/sparsetensor" },
             { text: "StableHLO", link: "/api/dialects/stablehlo" },
             { text: "Triton", link: "/api/dialects/triton" },
+            { text: "TritonExt", link: "/api/dialects/tritonext" },
             { text: "TPU", link: "/api/dialects/tpu" },
             { text: "VHLO", link: "/api/dialects/vhlo" },
           ],

docs/src/api/dialects/tritonext.md
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+```@meta
+CollapsedDocStrings = true
+```
+
+# TritonExt Dialect
+
+Provides extensions to the Triton dialect.
+
+```@autodocs
+Modules = [Reactant.MLIR.Dialects.triton_ext]
+```

src/Compiler.jl
Lines changed: 59 additions & 27 deletions

@@ -1298,57 +1298,89 @@ function optimization_passes(
 end
 
 # https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes
 function triton_optimization_passes()
-    # TODO: check that all triton passes are included here
-    return join(
+    all_passes = join(
         [
-            # convert passes
-            "convert-scf-to-cf",
-            "convert-cf-to-llvm",
-            "convert-index-to-llvm",
-            "convert-arith-to-llvm",
-            "convert-nvvm-to-llvm",
-            # common passes
             "canonicalize",
-            # ttir passes
+            "triton-rewrite-tensor-pointer",
+            "canonicalize",
             "triton-combine",
             "triton-reorder-broadcast",
-            "triton-rewrite-tensor-pointer",
-            "triton-rewrite-tensor-descriptor-to-pointer",
+            "cse",
+            "symbol-dce",
             "triton-loop-unroll",
-            "triton-licm",
-            "triton-loop-aware-cse",
-            # TODO: should num-warps and num-ctas be set for each kernel?
             "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
-            # ttgir passes
             "tritongpu-coalesce",
+            "tritongpu-F32DotTC",
+            "triton-nvidia-gpu-plan-cta",
+            "tritongpu-remove-layout-conversions",
             "tritongpu-optimize-thread-locality",
+            "tritongpu-accelerate-matmul",
+            "tritongpu-remove-layout-conversions",
+            "tritongpu-optimize-dot-operands",
+            "canonicalize",
+            "triton-nvidia-optimize-descriptor-encoding",
+            "triton-loop-aware-cse",
+            "tritongpu-fuse-nested-loops",
+            "canonicalize",
+            "triton-licm",
+            "tritongpu-optimize-accumulator-init",
             "tritongpu-hoist-tmem-alloc",
+            "tritongpu-promote-lhs-to-tmem",
             "tritongpu-assign-latencies",
-            "tritongpu-pipeline",
             "tritongpu-schedule-loops",
             "tritongpu-automatic-warp-specialization",
+            "tritongpu-partition-scheduling",
+            "tritongpu-load-mma-specialization",
+            "tritongpu-rewrite-partition-dependencies",
+            "sccp",
+            "cse",
+            "tritongpu-partition-loops",
+            "tritongpu-optimize-partition-warps",
+            "tritongpu-schedule-loops",
+            "tritongpu-pipeline",
+            "tritongpu-combine-tensor-select-and-if",
+            "triton-nvidia-gpu-remove-tmem-tokens",
+            "canonicalize",
+            "triton-loop-aware-cse",
             "tritongpu-prefetch",
-            "tritongpu-accelerate-matmul",
-            "tritongpu-reorder-instructions",
-            "tritongpu-F32DotTC",
             "tritongpu-optimize-dot-operands",
+            "canonicalize",
+            "tritongpu-coalesce-async-copy",
+            "triton-nvidia-optimize-tmem-layouts",
             "tritongpu-remove-layout-conversions",
+            "triton-nvidia-interleave-tmem",
             "tritongpu-reduce-data-duplication",
-            "tritongpu-hoist-tmem-alloc",
-            "tritongpu-fuse-nested-loops",
-            "tritongpu-rewrite-partition-dependencies",
-            "tritongpu-partition-loops",
+            "tritongpu-reorder-instructions",
+            "triton-loop-aware-cse",
+            "symbol-dce",
+            "triton-nvidia-tma-lowering",
+            "triton-nvidia-gpu-fence-insertion",
+            "sccp",
+            "canonicalize",
+            "triton-nvidia-mma-lowering",
             "tritongpu-combine-tensor-select-and-if",
-            # ttgir to llvm passes
             "tritongpu-allocate-warp-groups",
+            "convert-scf-to-cf",
             "allocate-shared-memory",
+            "triton-tensor-memory-allocation",
             "tritongpu-global-scratch-memory-allocation",
-            "tritongpu-optimize-accumulator-init",
-            "tritongpu-coalesce-async-copy",
+            # TODO: register the commented out passes
+            # "convert-triton-gpu-to-llvm",
+            "canonicalize",
+            "cse",
+            # "convert-nv-gpu-to-llvm",
+            # "convert-warp-specialize-to-llvm",
+            "reconcile-unrealized-casts",
+            "canonicalize",
+            "cse",
+            "symbol-dce",
+            "enable-line-info",
         ],
         ",",
    )
+    return "triton_ext.module(builtin.module($(all_passes)))"
 end
 
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
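The net effect is that the Triton pass list is now anchored under the new `triton_ext.module` op instead of being applied to the whole program. A minimal sketch of the string this function produces follows; the `cuda:80` target and warp size 32 are placeholder values standing in for `cubinChip[]` and `cuWarpSize[]`, and the pass list is abbreviated (the full list is in the diff above):

```julia
# Illustrative sketch only; mirrors the construction in triton_optimization_passes().
passes = [
    "canonicalize",
    "triton-rewrite-tensor-pointer",
    # ... full pass list from the diff above ...
    "convert-triton-to-tritongpu{target=cuda:80 num-warps=1 threads-per-warp=32 num-ctas=1}",
    "enable-line-info",
]
pipeline = "triton_ext.module(builtin.module($(join(passes, ","))))"
# => "triton_ext.module(builtin.module(canonicalize,triton-rewrite-tensor-pointer,...,enable-line-info))"
# Anchoring on triton_ext.module(builtin.module(...)) restricts these passes to the
# builtin.module ops nested inside triton_ext.module, so the surrounding program is untouched.
```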

src/Ops.jl
Lines changed: 38 additions & 22 deletions

@@ -3,7 +3,7 @@
 # Julia and Reactant semantics should be considered on the higher abstractions that use these
 module Ops
 using ..MLIR: MLIR
-using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
+using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext
 using ..Reactant:
     Reactant,
     TracedRArray,
@@ -1704,32 +1704,52 @@ function _extract_function(
     code::String;
     func_name::String="main",
     func_op_kind::String="func.func",
-    nested_module::Bool=false,
     location::MLIR.IR.Location=MLIR.IR.Location(),
 )
     module_suffix = string(hash(code); base=16)
     name_to_call = func_name * "_call_" * module_suffix
     mod_name = func_name * "_module_" * module_suffix
     symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
 
-    if nested_module
+    use_ttext_module = split(func_op_kind, ".")[1] == "tt"
+
+    if use_ttext_module
+        tt_mod_name = func_name * "_tt_module_" * module_suffix
+        tt_region = MLIR.IR.Region()
+        tt_block = MLIR.IR.Block()
+        push!(tt_region, tt_block)
+        triton_mod_op = triton_ext.module_(;
+            location, bodyRegion=tt_region, sym_name=tt_mod_name
+        )
+        MLIR.IR.rmfromparent!(triton_mod_op)
+        push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module
+
         region = MLIR.IR.Region()
         push!(region, MLIR.IR.Block())
         moduleop = MLIR.Dialects.builtin.module_(;
             location, bodyRegion=region, sym_name=mod_name
         )
         MLIR.IR.rmfromparent!(moduleop)
-        push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module
+        push!(tt_block, moduleop) # insert into triton module
 
         top_level_block = MLIR.IR.Block(
             MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false
         )
         fn = nothing
+
+        symref = MLIR.IR.SymbolRefAttribute(
+            tt_mod_name,
+            MLIR.IR.Attribute[
+                MLIR.IR.FlatSymbolRefAttribute(mod_name),
+                MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+            ],
+        )
     else
         current_module = MLIR.IR.mmodule()
         moduleop = MLIR.IR.Operation(current_module)
         top_level_block = MLIR.IR.body(current_module)
         fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call)
+        symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call)
     end
 
     if isnothing(fn)
@@ -1750,12 +1770,14 @@ function _extract_function(
             )
             @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
 
-            # Set function private
-            MLIR.IR.attr!(
-                op,
-                MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
-                MLIR.IR.Attribute("private"),
-            )
+            if !use_ttext_module
+                # Set function private
+                MLIR.IR.attr!(
+                    op,
+                    MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
+                    MLIR.IR.Attribute("private"),
+                )
+            end
 
             # Change function name
             MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
@@ -1770,7 +1792,7 @@ function _extract_function(
         error("hlo_call: could not find function $func_name in the provided module")
     end
 
-    return fn, name_to_call, mod_name
+    return fn, symref
 end
 
 function triton_call(
@@ -1784,19 +1806,15 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call, mod_name = _extract_function(
-        mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location
-    )
+    _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
 
-    enzymexla.triton_call(
+    triton_ext.call(
         grid_x.mlir_data,
         grid_y.mlir_data,
         grid_z.mlir_data,
         shmem.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
-        fn=MLIR.IR.SymbolRefAttribute(
-            mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
-        ),
+        fn=symref,
         result_0=MLIR.IR.Type[],
         location,
     )
@@ -1834,9 +1852,7 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, name_to_call, _ = _extract_function(
-        code; func_name, func_op_kind="func.func", location
-    )
+    fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location)
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
@@ -1853,7 +1869,7 @@ julia> Reactant.@jit(
     call = MLIR.Dialects.func.call(
         operands;
         result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
-        callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+        callee=symref,
         location,
     )
 
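For orientation, here is what a call site might look like after this change. This is a hypothetical sketch, not taken from the commit: the full `triton_call` signature (which of `grid_x`, `grid_y`, `grid_z`, `shmem`, and `func_name` are positional versus keyword arguments) is not visible in this diff, so treat the call shape below as an assumption to check against src/Ops.jl. What the diff does establish is that the kernel is parsed into a `builtin.module` nested inside a `triton_ext.module` and invoked via `triton_ext.call` through a nested symbol reference (tt module, then builtin module, then function) instead of the old `enzymexla.triton_call` reference.

```julia
# Hypothetical usage sketch; argument names and positional/keyword split are assumptions.
using Reactant

# Triton IR containing a `tt.func` named `add_kernel`; loaded from disk here to avoid
# hand-writing TTIR inline.
kernel_ir = read("add_kernel.ttir", String)

function launch!(out, x, y)
    Reactant.Ops.triton_call(
        kernel_ir,
        x, y, out;
        func_name="add_kernel",
        grid_x=Reactant.Ops.constant(Int32(4)),  # assumed: launch dims are traced scalars
        grid_y=Reactant.Ops.constant(Int32(1)),
        grid_z=Reactant.Ops.constant(Int32(1)),
        shmem=Reactant.Ops.constant(Int32(0)),
    )
    return out
end

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(rand(Float32, 1024))
out = Reactant.to_rarray(zeros(Float32, 1024))
Reactant.@jit launch!(out, x, y)
```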
