Skip to content

Commit 357f1c0

Browse files
committed
feat: Triton tracing now works
1 parent 30976c0 commit 357f1c0

File tree

4 files changed

+110
-49
lines changed

4 files changed

+110
-49
lines changed

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,7 @@ export default defineConfig({
131131
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
132132
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
133133
{ text: "Triton", link: "/api/dialects/triton" },
134+
{ text: "TritonExt", link: "/api/dialects/tritonext" },
134135
{ text: "TPU", link: "/api/dialects/tpu" },
135136
{ text: "VHLO", link: "/api/dialects/vhlo" },
136137
],
@@ -221,6 +222,7 @@ export default defineConfig({
221222
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
222223
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
223224
{ text: "Triton", link: "/api/dialects/triton" },
225+
{ text: "TritonExt", link: "/api/dialects/tritonext" },
224226
{ text: "TPU", link: "/api/dialects/tpu" },
225227
{ text: "VHLO", link: "/api/dialects/vhlo" },
226228
],

docs/src/api/dialects/tritonext.md

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# TritonExt Dialect
6+
7+
Provides extensions to the Triton dialect.
8+
9+
```@autodocs
10+
Modules = [Reactant.MLIR.Dialects.triton_ext]
11+
```

src/Compiler.jl

Lines changed: 59 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1305,57 +1305,89 @@ function optimization_passes(
13051305
end
13061306

13071307
# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
1308+
# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes
13081309
function triton_optimization_passes()
1309-
# TODO: check that all triton passes are included here
1310-
return join(
1310+
all_passes = join(
13111311
[
1312-
# convert passes
1313-
"convert-scf-to-cf",
1314-
"convert-cf-to-llvm",
1315-
"convert-index-to-llvm",
1316-
"convert-arith-to-llvm",
1317-
"convert-nvvm-to-llvm",
1318-
# common passes
13191312
"canonicalize",
1320-
# ttir passes
1313+
"triton-rewrite-tensor-pointer",
1314+
"canonicalize",
13211315
"triton-combine",
13221316
"triton-reorder-broadcast",
1323-
"triton-rewrite-tensor-pointer",
1324-
"triton-rewrite-tensor-descriptor-to-pointer",
1317+
"cse",
1318+
"symbol-dce",
13251319
"triton-loop-unroll",
1326-
"triton-licm",
1327-
"triton-loop-aware-cse",
1328-
# TODO: should num-warps and num-ctas be set for each kernel?
13291320
"convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
1330-
# ttgir passes
13311321
"tritongpu-coalesce",
1322+
"tritongpu-F32DotTC",
1323+
"triton-nvidia-gpu-plan-cta",
1324+
"tritongpu-remove-layout-conversions",
13321325
"tritongpu-optimize-thread-locality",
1326+
"tritongpu-accelerate-matmul",
1327+
"tritongpu-remove-layout-conversions",
1328+
"tritongpu-optimize-dot-operands",
1329+
"canonicalize",
1330+
"triton-nvidia-optimize-descriptor-encoding",
1331+
"triton-loop-aware-cse",
1332+
"tritongpu-fuse-nested-loops",
1333+
"canonicalize",
1334+
"triton-licm",
1335+
"tritongpu-optimize-accumulator-init",
13331336
"tritongpu-hoist-tmem-alloc",
1337+
"tritongpu-promote-lhs-to-tmem",
13341338
"tritongpu-assign-latencies",
1335-
"tritongpu-pipeline",
13361339
"tritongpu-schedule-loops",
13371340
"tritongpu-automatic-warp-specialization",
1341+
"tritongpu-partition-scheduling",
1342+
"tritongpu-load-mma-specialization",
1343+
"tritongpu-rewrite-partition-dependencies",
1344+
"sccp",
1345+
"cse",
1346+
"tritongpu-partition-loops",
1347+
"tritongpu-optimize-partition-warps",
1348+
"tritongpu-schedule-loops",
1349+
"tritongpu-pipeline",
1350+
"tritongpu-combine-tensor-select-and-if",
1351+
"triton-nvidia-gpu-remove-tmem-tokens",
1352+
"canonicalize",
1353+
"triton-loop-aware-cse",
13381354
"tritongpu-prefetch",
1339-
"tritongpu-accelerate-matmul",
1340-
"tritongpu-reorder-instructions",
1341-
"tritongpu-F32DotTC",
13421355
"tritongpu-optimize-dot-operands",
1356+
"canonicalize",
1357+
"tritongpu-coalesce-async-copy",
1358+
"triton-nvidia-optimize-tmem-layouts",
13431359
"tritongpu-remove-layout-conversions",
1360+
"triton-nvidia-interleave-tmem",
13441361
"tritongpu-reduce-data-duplication",
1345-
"tritongpu-hoist-tmem-alloc",
1346-
"tritongpu-fuse-nested-loops",
1347-
"tritongpu-rewrite-partition-dependencies",
1348-
"tritongpu-partition-loops",
1362+
"tritongpu-reorder-instructions",
1363+
"triton-loop-aware-cse",
1364+
"symbol-dce",
1365+
"triton-nvidia-tma-lowering",
1366+
"triton-nvidia-gpu-fence-insertion",
1367+
"sccp",
1368+
"canonicalize",
1369+
"triton-nvidia-mma-lowering",
13491370
"tritongpu-combine-tensor-select-and-if",
1350-
# ttgir to llvm passes
13511371
"tritongpu-allocate-warp-groups",
1372+
"convert-scf-to-cf",
13521373
"allocate-shared-memory",
1374+
"triton-tensor-memory-allocation",
13531375
"tritongpu-global-scratch-memory-allocation",
1354-
"tritongpu-optimize-accumulator-init",
1355-
"tritongpu-coalesce-async-copy",
1376+
# TODO: register the commented out passes
1377+
# "convert-triton-gpu-to-llvm",
1378+
"canonicalize",
1379+
"cse",
1380+
# "convert-nv-gpu-to-llvm",
1381+
# "convert-warp-specialize-to-llvm",
1382+
"reconcile-unrealized-casts",
1383+
"canonicalize",
1384+
"cse",
1385+
"symbol-dce",
1386+
"enable-line-info",
13561387
],
13571388
",",
13581389
)
1390+
return "triton_ext.module(builtin.module($(all_passes)))"
13591391
end
13601392

13611393
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate

src/Ops.jl

Lines changed: 38 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
# Julia and Reactant semantics should be considered on the higher abstractions that use these
44
module Ops
55
using ..MLIR: MLIR
6-
using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
6+
using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext
77
using ..Reactant:
88
Reactant,
99
TracedRArray,
@@ -1749,32 +1749,52 @@ function _extract_function(
17491749
code::String;
17501750
func_name::String="main",
17511751
func_op_kind::String="func.func",
1752-
nested_module::Bool=false,
17531752
location::MLIR.IR.Location=MLIR.IR.Location(),
17541753
)
17551754
module_suffix = string(hash(code); base=16)
17561755
name_to_call = func_name * "_call_" * module_suffix
17571756
mod_name = func_name * "_module_" * module_suffix
17581757
symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
17591758

1760-
if nested_module
1759+
use_ttext_module = split(func_op_kind, ".")[1] == "tt"
1760+
1761+
if use_ttext_module
1762+
tt_mod_name = func_name * "_tt_module_" * module_suffix
1763+
tt_region = MLIR.IR.Region()
1764+
tt_block = MLIR.IR.Block()
1765+
push!(tt_region, tt_block)
1766+
triton_mod_op = triton_ext.module_(;
1767+
location, bodyRegion=tt_region, sym_name=tt_mod_name
1768+
)
1769+
MLIR.IR.rmfromparent!(triton_mod_op)
1770+
push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module
1771+
17611772
region = MLIR.IR.Region()
17621773
push!(region, MLIR.IR.Block())
17631774
moduleop = MLIR.Dialects.builtin.module_(;
17641775
location, bodyRegion=region, sym_name=mod_name
17651776
)
17661777
MLIR.IR.rmfromparent!(moduleop)
1767-
push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module
1778+
push!(tt_block, moduleop) # insert into triton module
17681779

17691780
top_level_block = MLIR.IR.Block(
17701781
MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false
17711782
)
17721783
fn = nothing
1784+
1785+
symref = MLIR.IR.SymbolRefAttribute(
1786+
tt_mod_name,
1787+
MLIR.IR.Attribute[
1788+
MLIR.IR.FlatSymbolRefAttribute(mod_name),
1789+
MLIR.IR.FlatSymbolRefAttribute(name_to_call),
1790+
],
1791+
)
17731792
else
17741793
current_module = MLIR.IR.mmodule()
17751794
moduleop = MLIR.IR.Operation(current_module)
17761795
top_level_block = MLIR.IR.body(current_module)
17771796
fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call)
1797+
symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call)
17781798
end
17791799

17801800
if isnothing(fn)
@@ -1795,12 +1815,14 @@ function _extract_function(
17951815
)
17961816
@assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
17971817

1798-
# Set function private
1799-
MLIR.IR.attr!(
1800-
op,
1801-
MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
1802-
MLIR.IR.Attribute("private"),
1803-
)
1818+
if !use_ttext_module
1819+
# Set function private
1820+
MLIR.IR.attr!(
1821+
op,
1822+
MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
1823+
MLIR.IR.Attribute("private"),
1824+
)
1825+
end
18041826

18051827
# Change function name
18061828
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
@@ -1815,7 +1837,7 @@ function _extract_function(
18151837
error("hlo_call: could not find function $func_name in the provided module")
18161838
end
18171839

1818-
return fn, name_to_call, mod_name
1840+
return fn, symref
18191841
end
18201842

18211843
function triton_call(
@@ -1829,19 +1851,15 @@ function triton_call(
18291851
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18301852
# TODO: other kwargs
18311853
)
1832-
_, name_to_call, mod_name = _extract_function(
1833-
mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location
1834-
)
1854+
_, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
18351855

1836-
enzymexla.triton_call(
1856+
triton_ext.call(
18371857
grid_x.mlir_data,
18381858
grid_y.mlir_data,
18391859
grid_z.mlir_data,
18401860
shmem.mlir_data,
18411861
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
1842-
fn=MLIR.IR.SymbolRefAttribute(
1843-
mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
1844-
),
1862+
fn=symref,
18451863
result_0=MLIR.IR.Type[],
18461864
location,
18471865
)
@@ -1879,9 +1897,7 @@ julia> Reactant.@jit(
18791897
func_name="main",
18801898
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
18811899
)
1882-
fn, name_to_call, _ = _extract_function(
1883-
code; func_name, func_op_kind="func.func", location
1884-
)
1900+
fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location)
18851901

18861902
ftype_attr = MLIR.IR.attr(fn, "function_type")
18871903
ftype = MLIR.IR.Type(ftype_attr)
@@ -1898,7 +1914,7 @@ julia> Reactant.@jit(
18981914
call = MLIR.Dialects.func.call(
18991915
operands;
19001916
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
1901-
callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
1917+
callee=symref,
19021918
location,
19031919
)
19041920

0 commit comments

Comments
 (0)