Commit 5b37e06

feat: more triton passes + keep triton func in a separate module

1 parent: 4e953a5
9 files changed (+136, -18 lines)

deps/ReactantExtra/API.cpp

Lines changed: 42 additions & 2 deletions
@@ -487,8 +487,8 @@ MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices,
   return client.release();
 }
 #else
-  *error = "ReactantExtra was not built with GPU support";
-  return nullptr;
+  *error = "ReactantExtra was not built with GPU support";
+  return nullptr;
 #endif
 }

@@ -716,16 +716,56 @@ std::vector<int64_t> row_major(int64_t dim) {
 static void noop() {}
 
 #ifdef REACTANT_CUDA
+
 #include "third_party/gpus/cuda/include/cuda.h"
+
 REACTANT_ABI int32_t ReactantCudaDriverGetVersion() {
   int32_t data;
   ReactantHandleCuResult(cuDriverGetVersion(&data));
   return data;
 }
+
 REACTANT_ABI int32_t ReactantHermeticCudaGetVersion() { return CUDA_VERSION; }
+
+REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMajor() {
+  CUdevice cuDevice;
+  ReactantHandleCuResult(cuDeviceGet(&cuDevice, 0));
+  int major;
+  ReactantHandleCuResult(cuDeviceGetAttribute(
+      &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice));
+  return major;
+}
+
+REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() {
+  CUdevice cuDevice;
+  ReactantHandleCuResult(cuDeviceGet(&cuDevice, 0));
+  int minor;
+  ReactantHandleCuResult(cuDeviceGetAttribute(
+      &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice));
+  return minor;
+}
+
+REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() {
+  CUdevice cuDevice;
+  ReactantHandleCuResult(cuDeviceGet(&cuDevice, 0));
+  int warpSize;
+  ReactantHandleCuResult(cuDeviceGetAttribute(
+      &warpSize, CU_DEVICE_ATTRIBUTE_WARP_SIZE, cuDevice));
+  return warpSize;
+}
+
 #else
+
 REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; }
+
 REACTANT_ABI int32_t ReactantHermeticCudaGetVersion() { return 0; }
+
+REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMajor() { return 0; }
+
+REACTANT_ABI int32_t ReactantCudaDeviceGetComputeCapalilityMinor() { return 0; }
+
+REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; }
+
 #endif
 
 REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
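
The new entry points are meant to be reached over the C ABI from Julia (the src/xla/XLA.jl hunk below does exactly this through MLIR.API.mlir_c). A minimal sketch of querying them directly; the library constant here is only a placeholder for however the caller loads ReactantExtra:

    # Sketch: query the default CUDA device via the new ReactantExtra symbols.
    # `libReactantExtra` is a hypothetical library name; Reactant itself goes
    # through MLIR.API.mlir_c instead.
    const libReactantExtra = "libReactantExtra.so"

    cc_major = @ccall libReactantExtra.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
    cc_minor = @ccall libReactantExtra.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
    warp     = @ccall libReactantExtra.ReactantCudaDeviceGetWarpSizeInThreads()::Int32

    # The non-CUDA stubs all return 0, so (0, 0, 0) means "built without GPU support".
    println("compute capability sm_$(cc_major)$(cc_minor), warp size = $warp")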

deps/ReactantExtra/BUILD

Lines changed: 3 additions & 0 deletions
@@ -979,6 +979,9 @@ cc_library(
         "-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
         "-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion",
         "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
+        "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor",
+        "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
+        "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
         "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions",
         "-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
         "-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
 
 NSYNC_SHA256 = ""
 
-ENZYMEXLA_COMMIT = "fba59b4000ea352b145a14e1384b8c2940299987"
+ENZYMEXLA_COMMIT = "b59185c7586783a17d9486e682307ae89c713964"
 
 ENZYMEXLA_SHA256 = ""

ext/ReactantCUDAExt.jl

Lines changed: 0 additions & 8 deletions
@@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
     return newa
 end
 
-function __init__()
-    if CUDA.functional() && !Reactant.precompiling()
-        cap = CUDA.capability(CUDA.device())
-        Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
-    end
-    return nothing
-end
-
 # In Julia v1.11.3 precompiling this module caches bad code:
 # <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
 @static if !Sys.isapple()

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 7 additions & 3 deletions
@@ -60,6 +60,7 @@ signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
 signature_string(x::T) where {T<:Number} = string(x), x
 signature_string(x) = error("Unsupported argument type: $(typeof(x))")
 
+# TODO: better name for hints?
 function overlayed_pycall_with_triton(
     kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
 )

@@ -95,8 +96,11 @@
         fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
     )
 
-    # TODO: check that we are using CUDA. Get compute_capability from the target
-    target = triton.backends.compiler.GPUTarget("cuda", 80, 32)
+    target = triton.backends.compiler.GPUTarget(
+        "cuda",
+        parse(Int, Reactant.Compiler.cubinChip[][4:end]),
+        Reactant.Compiler.cuWarpSize[],
+    )
     backend = triton.compiler.make_backend(target)
     options = backend.parse_options(
         pydict(

@@ -111,7 +115,7 @@
     ccinfo = triton.compile(src; target=target, options=options.__dict__)
 
     @opcall triton_call(
-        pyconvert(String, ccinfo.asm["ttir"]),
+        pyconvert(String, ccinfo.asm["source"]),
        filter(x -> x isa Reactant.TracedType, args)...;
        func_name=pyconvert(String, ccinfo.metadata.name),
        grid_x=@opcall(constant(grid[1])),
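
The GPUTarget is now built from the values recorded during XLA client initialization (see src/xla/XLA.jl below) instead of the hard-coded (80, 32). A small self-contained illustration of the string slicing, assuming cubinChip holds the usual "sm_<major><minor>" form:

    # Illustration only: recovering the numeric compute capability from the
    # "sm_XY" string stored in Reactant.Compiler.cubinChip[].
    cubin_chip = "sm_80"                                 # e.g. what an A100 would record
    compute_capability = parse(Int, cubin_chip[4:end])   # "80" -> 80
    warp_size = 32                                       # Reactant.Compiler.cuWarpSize[] default

    @assert (compute_capability, warp_size) == (80, 32)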

src/Compiler.jl

Lines changed: 60 additions & 1 deletion
@@ -1291,9 +1291,66 @@ function optimization_passes(
         push!(passes, "remove-duplicate-func-def")
     end
     push!(passes, func_passes)
+    if backend == "cuda"
+        push!(passes, triton_optimization_passes())
+    end
     return join(passes, ',')
 end
 
+# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
+function triton_optimization_passes()
+    # TODO: check that all triton passes are included here
+    return join(
+        [
+            # convert passes
+            "convert-scf-to-cf",
+            "convert-cf-to-llvm",
+            "convert-index-to-llvm",
+            "convert-arith-to-llvm",
+            "convert-nvvm-to-llvm",
+            # common passes
+            "canonicalize",
+            # # ttir passes
+            # "triton-combine",
+            # "triton-reorder-broadcast",
+            # "triton-rewrite-tensor-pointer",
+            # "triton-rewrite-tensor-descriptor-to-pointer",
+            # "triton-loop-unroll",
+            # "triton-licm",
+            # "triton-loop-aware-cse",
+            # # TODO: should num-warps and num-ctas be set for each kernel?
+            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # # ttgir passes
+            # "tritongpu-coalesce",
+            # "tritongpu-optimize-thread-locality",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-assign-latencies",
+            # "tritongpu-pipeline",
+            # "tritongpu-schedule-loops",
+            # "tritongpu-automatic-warp-specialization",
+            # "tritongpu-prefetch",
+            # "tritongpu-accelerate-matmul",
+            # "tritongpu-reorder-instructions",
+            # "tritongpu-F32DotTC",
+            # "tritongpu-optimize-dot-operands",
+            # "tritongpu-remove-layout-conversions",
+            # "tritongpu-reduce-data-duplication",
+            # "tritongpu-hoist-tmem-alloc",
+            # "tritongpu-fuse-nested-loops",
+            # "tritongpu-rewrite-partition-dependencies",
+            # "tritongpu-partition-loops",
+            # "tritongpu-combine-tensor-select-and-if",
+            # # ttgir to llvm passes
+            # "tritongpu-allocate-warp-groups",
+            # "allocate-shared-memory",
+            # "tritongpu-global-scratch-memory-allocation",
+            # "tritongpu-optimize-accumulator-init",
+            # "tritongpu-coalesce-async-copy",
+        ],
+        ",",
+    )
+end
+
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
 # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
 const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"

@@ -1425,6 +1482,7 @@ const cubinChip = Ref{String}("sm_60")
 const cubinFormat = Ref{String}("bin")
 const cuindexBitWidth = Ref{Int}(32)
 const cuOptLevel = Ref{Int}(2)
+const cuWarpSize = Ref{Int}(32)
 # Wgatever the relevant highest version from our LLVM is within NVPTX.td
 # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
 # We see relevant ptx version is CUDA 12.6 -> 85

@@ -2245,7 +2303,8 @@ function compile_mlir!(
         end
     end
 
-    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    # XXX: re-enable this pass
+    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
 
     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname
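
triton_optimization_passes() only assembles a comma-separated pass list; optimization_passes appends it when the backend is "cuda", and the combined string is what ultimately reaches the pass pipeline. A rough sketch of what the enabled portion evaluates to (pass names copied from the hunk above; the commented-out ttir/ttgir passes are not included yet, and the surrounding pass list here is a stand-in):

    # Sketch: the currently enabled Triton-related passes, joined the same way
    # triton_optimization_passes() joins them.
    triton_passes = join(
        [
            "convert-scf-to-cf",
            "convert-cf-to-llvm",
            "convert-index-to-llvm",
            "convert-arith-to-llvm",
            "convert-nvvm-to-llvm",
            "canonicalize",
        ],
        ",",
    )

    # optimization_passes(...) effectively does this for backend == "cuda":
    passes = ["canonicalize", "remove-duplicate-func-def"]  # stand-ins for the earlier entries
    push!(passes, triton_passes)
    pipeline = join(passes, ',')
    println(pipeline)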

src/Ops.jl

Lines changed: 12 additions & 2 deletions
@@ -1701,12 +1701,20 @@ end
 _new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix
 
 function _extract_function(
-    code::String; func_name::String="main", func_op_kind::String="func.func"
+    code::String;
+    func_name::String="main",
+    func_op_kind::String="func.func",
+    nested_module::Bool=false,
 )
     module_suffix = string(hash(code); base=16)
     name_to_call = _new_function_name(func_name, module_suffix)
 
     current_module = MLIR.IR.mmodule()
+    if nested_module
+        new_module = MLIR.IR.Module()
+        push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true))
+        current_module = new_module
+    end
     top_level_block = MLIR.IR.body(current_module)
 
     symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())

@@ -1770,7 +1778,9 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
+    _, name_to_call = _extract_function(
+        mlir_code; func_name, func_op_kind="tt.func", nested_module=true
+    )
 
     enzymexla.triton_call(
         grid_x.mlir_data,
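
With nested_module=true the extracted function is parsed into a fresh module that is spliced into the current one, so the tt.func lives in its own nested module rather than in the top-level symbol table (the "separate module" part of the commit message). A hedged sketch of the call as triton_call now makes it, assuming the internal helper is reachable as Reactant.Ops._extract_function and that mlir_code/func_name are the TTIR text and kernel name obtained from Triton in pycall.jl:

    # Sketch (internal API, must run inside an active MLIR module context).
    _, name_to_call = Reactant.Ops._extract_function(
        mlir_code;
        func_name=func_name,
        func_op_kind="tt.func",
        nested_module=true,   # parse into a nested module instead of the top level
    )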

src/mlir/IR/Module.jl

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ body(module_) = Block(API.mlirModuleGetBody(module_), false)
 
 Views the module as a generic operation.
 """
-Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false)
+Operation(module_::Module, owned::Bool=false) =
+    Operation(API.mlirModuleGetOperation(module_), owned)
 
 function Base.show(io::IO, module_::Module)
     return show(io, Operation(module_))
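
The new owned flag matters when a module's operation is handed to another block: passing owned=true transfers ownership so the wrapper does not free it a second time. A minimal sketch of the pattern Ops.jl uses above (assuming an active MLIR module, e.g. during tracing):

    # Sketch: splice a fresh module into the one currently being built.
    parent = MLIR.IR.mmodule()     # the module under construction
    child  = MLIR.IR.Module()      # new, empty module
    # owned=true hands ownership of the child's op to the parent block.
    push!(MLIR.IR.body(parent), MLIR.IR.Operation(child, true))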

src/xla/XLA.jl

Lines changed: 9 additions & 0 deletions
@@ -234,6 +234,15 @@ for runtime in (:PJRT, :IFRT)
            )
            state.clients["cuda"] = gpu
            state.default_client = gpu
+
+            # set values for cuda. This is being done here since we need cuda
+            # to be initialized before we can use it. initializing the devices
+            # implicitly initializes cuda.
+            cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
+            cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
+            Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
+
+            Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
        catch e
            println(stdout, e)
        end
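
Net effect: once the first CUDA client is created, cubinChip and cuWarpSize describe the actual device instead of the sm_60 / 32 defaults, and the Triton GPUTarget in pycall.jl reads from them. An illustrative check after initialization (the printed values depend on the GPU; sm_80 and 32 are what an A100 would report):

    # Illustrative only: inspect the values populated during client init.
    using Reactant
    @show Reactant.Compiler.cubinChip[]    # e.g. "sm_80"
    @show Reactant.Compiler.cuWarpSize[]   # 32 on current NVIDIA hardware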
