Commit 8ae9ebe (parent: fb4f002)

feat: triton tracing works now finally

4 files changed: +110 −49

docs/src/.vitepress/config.mts
docs/src/api/dialects/tritonext.md
src/Compiler.jl
src/Ops.jl

docs/src/.vitepress/config.mts (2 additions, 0 deletions)

@@ -131,6 +131,7 @@ export default defineConfig({
       { text: "SparseTensor", link: "/api/dialects/sparsetensor" },
       { text: "StableHLO", link: "/api/dialects/stablehlo" },
       { text: "Triton", link: "/api/dialects/triton" },
+      { text: "TritonExt", link: "/api/dialects/tritonext" },
       { text: "TPU", link: "/api/dialects/tpu" },
       { text: "VHLO", link: "/api/dialects/vhlo" },
     ],

@@ -221,6 +222,7 @@ export default defineConfig({
       { text: "SparseTensor", link: "/api/dialects/sparsetensor" },
       { text: "StableHLO", link: "/api/dialects/stablehlo" },
       { text: "Triton", link: "/api/dialects/triton" },
+      { text: "TritonExt", link: "/api/dialects/tritonext" },
       { text: "TPU", link: "/api/dialects/tpu" },
       { text: "VHLO", link: "/api/dialects/vhlo" },
     ],

docs/src/api/dialects/tritonext.md (11 additions, 0 deletions, new file)

@@ -0,0 +1,11 @@
+```@meta
+CollapsedDocStrings = true
+```
+
+# TritonExt Dialect
+
+Provides extensions to the Triton dialect.
+
+```@autodocs
+Modules = [Reactant.MLIR.Dialects.triton_ext]
+```

src/Compiler.jl (59 additions, 27 deletions)

@@ -1307,57 +1307,89 @@ function optimization_passes(
 end

 # https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes
 function triton_optimization_passes()
-    # TODO: check that all triton passes are included here
-    return join(
+    all_passes = join(
         [
-            # convert passes
-            "convert-scf-to-cf",
-            "convert-cf-to-llvm",
-            "convert-index-to-llvm",
-            "convert-arith-to-llvm",
-            "convert-nvvm-to-llvm",
-            # common passes
             "canonicalize",
-            # ttir passes
+            "triton-rewrite-tensor-pointer",
+            "canonicalize",
             "triton-combine",
             "triton-reorder-broadcast",
-            "triton-rewrite-tensor-pointer",
-            "triton-rewrite-tensor-descriptor-to-pointer",
+            "cse",
+            "symbol-dce",
             "triton-loop-unroll",
-            "triton-licm",
-            "triton-loop-aware-cse",
-            # TODO: should num-warps and num-ctas be set for each kernel?
             "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
-            # ttgir passes
             "tritongpu-coalesce",
+            "tritongpu-F32DotTC",
+            "triton-nvidia-gpu-plan-cta",
+            "tritongpu-remove-layout-conversions",
             "tritongpu-optimize-thread-locality",
+            "tritongpu-accelerate-matmul",
+            "tritongpu-remove-layout-conversions",
+            "tritongpu-optimize-dot-operands",
+            "canonicalize",
+            "triton-nvidia-optimize-descriptor-encoding",
+            "triton-loop-aware-cse",
+            "tritongpu-fuse-nested-loops",
+            "canonicalize",
+            "triton-licm",
+            "tritongpu-optimize-accumulator-init",
             "tritongpu-hoist-tmem-alloc",
+            "tritongpu-promote-lhs-to-tmem",
             "tritongpu-assign-latencies",
-            "tritongpu-pipeline",
             "tritongpu-schedule-loops",
             "tritongpu-automatic-warp-specialization",
+            "tritongpu-partition-scheduling",
+            "tritongpu-load-mma-specialization",
+            "tritongpu-rewrite-partition-dependencies",
+            "sccp",
+            "cse",
+            "tritongpu-partition-loops",
+            "tritongpu-optimize-partition-warps",
+            "tritongpu-schedule-loops",
+            "tritongpu-pipeline",
+            "tritongpu-combine-tensor-select-and-if",
+            "triton-nvidia-gpu-remove-tmem-tokens",
+            "canonicalize",
+            "triton-loop-aware-cse",
             "tritongpu-prefetch",
-            "tritongpu-accelerate-matmul",
-            "tritongpu-reorder-instructions",
-            "tritongpu-F32DotTC",
             "tritongpu-optimize-dot-operands",
+            "canonicalize",
+            "tritongpu-coalesce-async-copy",
+            "triton-nvidia-optimize-tmem-layouts",
             "tritongpu-remove-layout-conversions",
+            "triton-nvidia-interleave-tmem",
             "tritongpu-reduce-data-duplication",
-            "tritongpu-hoist-tmem-alloc",
-            "tritongpu-fuse-nested-loops",
-            "tritongpu-rewrite-partition-dependencies",
-            "tritongpu-partition-loops",
+            "tritongpu-reorder-instructions",
+            "triton-loop-aware-cse",
+            "symbol-dce",
+            "triton-nvidia-tma-lowering",
+            "triton-nvidia-gpu-fence-insertion",
+            "sccp",
+            "canonicalize",
+            "triton-nvidia-mma-lowering",
             "tritongpu-combine-tensor-select-and-if",
-            # ttgir to llvm passes
             "tritongpu-allocate-warp-groups",
+            "convert-scf-to-cf",
             "allocate-shared-memory",
+            "triton-tensor-memory-allocation",
             "tritongpu-global-scratch-memory-allocation",
-            "tritongpu-optimize-accumulator-init",
-            "tritongpu-coalesce-async-copy",
+            # TODO: register the commented out passes
+            # "convert-triton-gpu-to-llvm",
+            "canonicalize",
+            "cse",
+            # "convert-nv-gpu-to-llvm",
+            # "convert-warp-specialize-to-llvm",
+            "reconcile-unrealized-casts",
+            "canonicalize",
+            "cse",
+            "symbol-dce",
+            "enable-line-info",
         ],
         ",",
     )
+    return "triton_ext.module(builtin.module($(all_passes)))"
 end

 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
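The rewritten `triton_optimization_passes` no longer returns a bare pass list: it wraps the joined passes in a `triton_ext.module(builtin.module(...))` anchor, so the Triton pipeline only runs on `builtin.module` ops nested inside the new wrapper op. A minimal standalone sketch of that string construction (shortened pass list for illustration; the real list is the one in the hunk above):

```julia
# Join pass names and nest the pipeline under the triton_ext.module anchor,
# mirroring the return value of triton_optimization_passes().
passes = ["canonicalize", "triton-combine", "cse", "symbol-dce"]
pipeline = "triton_ext.module(builtin.module($(join(passes, ","))))"
println(pipeline)
# triton_ext.module(builtin.module(canonicalize,triton-combine,cse,symbol-dce))
```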

src/Ops.jl (38 additions, 22 deletions)

@@ -3,7 +3,7 @@
 # Julia and Reactant semantics should be considered on the higher abstractions that use these
 module Ops
 using ..MLIR: MLIR
-using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
+using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext
 using ..Reactant:
     Reactant,
     TracedRArray,
@@ -1749,32 +1749,52 @@ function _extract_function(
     code::String;
     func_name::String="main",
     func_op_kind::String="func.func",
-    nested_module::Bool=false,
     location::MLIR.IR.Location=MLIR.IR.Location(),
 )
     module_suffix = string(hash(code); base=16)
     name_to_call = func_name * "_call_" * module_suffix
     mod_name = func_name * "_module_" * module_suffix
     symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())

-    if nested_module
+    use_ttext_module = split(func_op_kind, ".")[1] == "tt"
+
+    if use_ttext_module
+        tt_mod_name = func_name * "_tt_module_" * module_suffix
+        tt_region = MLIR.IR.Region()
+        tt_block = MLIR.IR.Block()
+        push!(tt_region, tt_block)
+        triton_mod_op = triton_ext.module_(;
+            location, bodyRegion=tt_region, sym_name=tt_mod_name
+        )
+        MLIR.IR.rmfromparent!(triton_mod_op)
+        push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module
+
         region = MLIR.IR.Region()
         push!(region, MLIR.IR.Block())
         moduleop = MLIR.Dialects.builtin.module_(;
             location, bodyRegion=region, sym_name=mod_name
         )
         MLIR.IR.rmfromparent!(moduleop)
-        push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module
+        push!(tt_block, moduleop) # insert into triton module

         top_level_block = MLIR.IR.Block(
             MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false
         )
         fn = nothing
+
+        symref = MLIR.IR.SymbolRefAttribute(
+            tt_mod_name,
+            MLIR.IR.Attribute[
+                MLIR.IR.FlatSymbolRefAttribute(mod_name),
+                MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+            ],
+        )
     else
         current_module = MLIR.IR.mmodule()
         moduleop = MLIR.IR.Operation(current_module)
         top_level_block = MLIR.IR.body(current_module)
         fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call)
+        symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call)
     end

     if isnothing(fn)
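`_extract_function` now hands back a ready-made symbol reference instead of raw names, and the shape of that reference depends on the dialect prefix of `func_op_kind`. A standalone sketch (hypothetical helper names, not Reactant code) of the two shapes, written out in MLIR's textual syntax for symbol references:

```julia
# "tt.*" functions get wrapped as triton_ext.module -> builtin.module -> func,
# so the callee needs a nested symbol reference; "func.func" stays flat.
is_tt(func_op_kind) = split(func_op_kind, ".")[1] == "tt"

symref_repr(kind, tt_mod, mod, fname) =
    is_tt(kind) ? "@$tt_mod::@$mod::@$fname" : "@$fname"

println(symref_repr("tt.func", "k_tt_module_ab12", "k_module_ab12", "k_call_ab12"))
# @k_tt_module_ab12::@k_module_ab12::@k_call_ab12
println(symref_repr("func.func", "", "", "main_call_ab12"))
# @main_call_ab12
```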
@@ -1795,12 +1815,14 @@ function _extract_function(
         )
         @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"

-        # Set function private
-        MLIR.IR.attr!(
-            op,
-            MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
-            MLIR.IR.Attribute("private"),
-        )
+        if !use_ttext_module
+            # Set function private
+            MLIR.IR.attr!(
+                op,
+                MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
+                MLIR.IR.Attribute("private"),
+            )
+        end

         # Change function name
         MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))

@@ -1815,7 +1837,7 @@ function _extract_function(
         error("hlo_call: could not find function $func_name in the provided module")
     end

-    return fn, name_to_call, mod_name
+    return fn, symref
 end

 function triton_call(
@@ -1829,19 +1851,15 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call, mod_name = _extract_function(
-        mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location
-    )
+    _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)

-    enzymexla.triton_call(
+    triton_ext.call(
         grid_x.mlir_data,
         grid_y.mlir_data,
         grid_z.mlir_data,
         shmem.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
-        fn=MLIR.IR.SymbolRefAttribute(
-            mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
-        ),
+        fn=symref,
         result_0=MLIR.IR.Type[],
         location,
     )
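With that plumbing in place, kernel launches lower to `triton_ext.call` instead of `enzymexla.triton_call`. A hedged usage sketch follows; the kernel IR is a placeholder, and the positional/keyword layout of `triton_call` is an assumption for illustration (the hunk above only shows the tail of the signature):

```julia
using Reactant

# Placeholder Triton IR; a real kernel body would replace the `...`.
kernel_ir = """
tt.func public @add_kernel(%x: !tt.ptr<f32>, %y: !tt.ptr<f32>) {
  ...
  tt.return
}
"""

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(rand(Float32, 1024))

# Assumed call shape: IR string first, then kernel args; grid sizes and
# dynamic shared memory passed as keywords.
Reactant.@jit Reactant.Ops.triton_call(
    kernel_ir, x, y;
    func_name="add_kernel", grid_x=4, grid_y=1, grid_z=1, shmem=0
)
```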
@@ -1879,9 +1897,7 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, name_to_call, _ = _extract_function(
-        code; func_name, func_op_kind="func.func", location
-    )
+    fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location)

     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)

@@ -1898,7 +1914,7 @@ julia> Reactant.@jit(
     call = MLIR.Dialects.func.call(
         operands;
         result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
-        callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+        callee=symref,
         location,
     )
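The `func.func` path is unchanged in behavior: `hlo_call` now simply receives its flat callee reference from `_extract_function`. For contrast with the Triton path, a sketch of that entry point (the module IR and shapes here are illustrative, not from this commit):

```julia
using Reactant

Reactant.@jit(
    Reactant.Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
        """,
        Reactant.to_rarray(Float32[1, 2, 3]),
        Reactant.to_rarray(Float32[4, 5, 6]),
    )
)
```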