
Commit 013180c

feat: more triton passes + keep triton func in a separate module

1 parent: 933f67a
5 files changed: +83 −7 lines changed

deps/ReactantExtra/BUILD

Lines changed: 3 additions & 0 deletions

@@ -979,6 +979,9 @@ cc_library(
         "-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
         "-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion",
         "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
+        "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor",
+        "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
+        "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
         "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions",
         "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor",
         "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 7 additions & 3 deletions

@@ -60,6 +60,7 @@ signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
 signature_string(x::T) where {T<:Number} = string(x), x
 signature_string(x) = error("Unsupported argument type: $(typeof(x))")

+# TODO: better name for hints?
 function overlayed_pycall_with_triton(
     kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
 )
@@ -95,8 +96,11 @@ function overlayed_pycall_with_triton(
         fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
     )

-    # TODO: check that we are using CUDA. Get compute_capability from the target
-    target = triton.backends.compiler.GPUTarget("cuda", 80, 32)
+    target = triton.backends.compiler.GPUTarget(
+        "cuda",
+        parse(Int, Reactant.Compiler.cubinChip[][4:end]),
+        Reactant.Compiler.cuWarpSize[],
+    )
     backend = triton.compiler.make_backend(target)
     options = backend.parse_options(
         pydict(
@@ -111,7 +115,7 @@ function overlayed_pycall_with_triton(
     ccinfo = triton.compile(src; target=target, options=options.__dict__)

     @opcall triton_call(
-        pyconvert(String, ccinfo.asm["ttir"]),
+        pyconvert(String, ccinfo.asm["source"]),
         filter(x -> x isa Reactant.TracedType, args)...;
         func_name=pyconvert(String, ccinfo.metadata.name),
         grid_x=@opcall(constant(grid[1])),
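The hard-coded `GPUTarget("cuda", 80, 32)` is replaced with values taken from Reactant's compiler state: `cubinChip[]` holds a string such as `"sm_80"`, so dropping the `"sm_"` prefix and parsing the remainder yields the compute capability, while `cuWarpSize[]` supplies the warp size. A quick check of that parsing, with `"sm_80"` as a representative value:

```julia
# Representative value; at runtime this comes from Reactant.Compiler.cubinChip[].
cubin_chip = "sm_80"
compute_capability = parse(Int, cubin_chip[4:end])  # indices 4:end skip "sm_"
@assert compute_capability == 80
```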

src/Compiler.jl

Lines changed: 59 additions & 1 deletion

@@ -1298,9 +1298,66 @@ function optimization_passes(
         push!(passes, "remove-duplicate-func-def")
     end
     push!(passes, func_passes)
+    if backend == "cuda"
+        push!(passes, triton_optimization_passes())
+    end
     return join(passes, ',')
 end

+# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+function triton_optimization_passes()
+    # TODO: check that all triton passes are included here
+    return join(
+        [
+            # convert passes
+            "convert-scf-to-cf",
+            "convert-cf-to-llvm",
+            "convert-index-to-llvm",
+            "convert-arith-to-llvm",
+            "convert-nvvm-to-llvm",
+            # common passes
+            "canonicalize",
+            # # ttir passes
+            # "triton-combine",
+            # "triton-reorder-broadcast",
+            # "triton-rewrite-tensor-pointer",
+            # "triton-rewrite-tensor-descriptor-to-pointer",
+            # "triton-loop-unroll",
+            # "triton-licm",
+            # "triton-loop-aware-cse",
+            # # TODO: should num-warps and num-ctas be set for each kernel?
+            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # # ttgir passes
+            # "tritongpu-coalesce",
+            # "tritongpu-optimize-thread-locality",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-assign-latencies",
+            # "tritongpu-pipeline",
+            # "tritongpu-schedule-loops",
+            # "tritongpu-automatic-warp-specialization",
+            # "tritongpu-prefetch",
+            # "tritongpu-accelerate-matmul",
+            # "tritongpu-reorder-instructions",
+            # "tritongpu-F32DotTC",
+            # "tritongpu-optimize-dot-operands",
+            # "tritongpu-remove-layout-conversions",
+            # "tritongpu-reduce-data-duplication",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-fuse-nested-loops",
+            # "tritongpu-rewrite-partition-dependencies",
+            # "tritongpu-partition-loops",
+            # "tritongpu-combine-tensor-select-and-if",
+            # # ttgir to llvm passes
+            # "tritongpu-allocate-warp-groups",
+            # "allocate-shared-memory",
+            # "tritongpu-global-scratch-memory-allocation",
+            # "tritongpu-optimize-accumulator-init",
+            # "tritongpu-coalesce-async-copy",
+        ],
+        ",",
+    )
+end
+
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
 # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
 const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
@@ -2254,7 +2311,8 @@ function compile_mlir!(
         end
     end

-    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    # XXX: re-enable this pass
+    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")

     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname
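Only the conversion and canonicalization passes are active for now; everything Triton-specific is commented out. Since `optimization_passes` joins its `passes` vector with commas, the fragment returned by `triton_optimization_passes()` splices cleanly into the larger pipeline string. A small sketch of the observable result, restricted to the uncommented entries:

```julia
# Mirror of the active entries in triton_optimization_passes().
triton_passes = join(
    [
        "convert-scf-to-cf",
        "convert-cf-to-llvm",
        "convert-index-to-llvm",
        "convert-arith-to-llvm",
        "convert-nvvm-to-llvm",
        "canonicalize",
    ],
    ",",
)

passes = String[]
push!(passes, "canonicalize")  # stand-in for the passes already collected
push!(passes, triton_passes)   # appended only when backend == "cuda"
pipeline = join(passes, ',')   # one flat string for MLIR's pipeline parser
```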

src/Ops.jl

Lines changed: 12 additions & 2 deletions

@@ -1746,12 +1746,20 @@
 _new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix

 function _extract_function(
-    code::String; func_name::String="main", func_op_kind::String="func.func"
+    code::String;
+    func_name::String="main",
+    func_op_kind::String="func.func",
+    nested_module::Bool=false,
 )
     module_suffix = string(hash(code); base=16)
     name_to_call = _new_function_name(func_name, module_suffix)

     current_module = MLIR.IR.mmodule()
+    if nested_module
+        new_module = MLIR.IR.Module()
+        push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true))
+        current_module = new_module
+    end
     top_level_block = MLIR.IR.body(current_module)

     symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
@@ -1815,7 +1823,9 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
+    _, name_to_call = _extract_function(
+        mlir_code; func_name, func_op_kind="tt.func", nested_module=true
+    )

     enzymexla.triton_call(
         grid_x.mlir_data,
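With `nested_module=true`, the Triton function is parsed into a fresh `MLIR.IR.Module` whose operation is pushed into the parent module's body, so the `tt.func` lives in its own nested symbol table rather than the top-level one. A hypothetical call, assuming `ttir` holds the Triton IR text produced by `triton.compile` and that `_extract_function` is reachable as shown:

```julia
# Sketch: extract a tt.func into a nested module rather than the
# top-level one. `ttir` is assumed to contain a tt.func named "main".
_, name_to_call = Reactant.Ops._extract_function(
    ttir; func_name="main", func_op_kind="tt.func", nested_module=true
)
# name_to_call is "main_call_<hash>", where <hash> is
# string(hash(ttir); base=16), as built by _new_function_name.
```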

src/mlir/IR/Module.jl

Lines changed: 2 additions & 1 deletion

@@ -52,7 +52,8 @@ body(module_) = Block(API.mlirModuleGetBody(module_), false)

 Views the module as a generic operation.
 """
-Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false)
+Operation(module_::Module, owned::Bool=false) =
+    Operation(API.mlirModuleGetOperation(module_), owned)

 function Base.show(io::IO, module_::Module)
     return show(io, Operation(module_))
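The new `owned` flag exists to support the nested-module insertion above: pushing an operation into a block hands it over to that block, so the caller must pass an owned wrapper rather than the default borrowed view. A minimal sketch of the two patterns, assuming `owned=true` means the wrapper (not the original module) is now responsible for the op:

```julia
using Reactant: MLIR

parent = MLIR.IR.Module()
child = MLIR.IR.Module()

# Default: a borrowed view; the child module still owns its operation.
op_view = MLIR.IR.Operation(child)  # owned == false

# Transfer ownership so the parent block can take the op, mirroring
# _extract_function's nested_module branch.
push!(MLIR.IR.body(parent), MLIR.IR.Operation(child, true))
```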
