
Commit 0bb32b8

feat: initial triton setup [skip ci]
feat: auto-trace triton code
feat: copy tt.func into main module [skip ci]
feat: tracing fully functional
fix: hlo_call
feat: more triton passes + keep triton func in a separate module
feat: put the tt func in a separate module and use symbol ref
feat: new triton_ext dialect
feat: triton tracing works now finally
fix: kind of working
fix: new API
feat: return values
feat: lowering triton now works
feat: triton working end to end
fix: extra export + naming
feat: allow grid/blocks via a function [skip ci]
feat: use new device properties [skip ci]
feat: correctly set strides + get n_regs
test: add some triton tests
test: layer_norm + libdevice
fix: partial fix to the blocks
fix: correct launch configuration
test: missing vars
chore: bump workspace
fix: cluster dims
fix: bump version
chore: bump
1 parent 41028a3 commit 0bb32b8

23 files changed: +1376 −82 lines

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4"
 jax = ">= 0.6"
 tensorflow = ">= 2.17"
 numpy = ">= 2"
+triton = ">= 3.4"

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 37 additions & 1 deletion
@@ -1,14 +1,20 @@
 module ReactantPythonCallExt

-using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
+using PythonCall:
+    PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple
 using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
 using Reactant.Ops: @opcall
+using Reactant_jll: Reactant_jll

 const jaxptr = Ref{Py}()
 const jnpptr = Ref{Py}()

 const JAX_TRACING_SUPPORTED = Ref{Bool}(false)

+const tritonptr = Ref{Py}()
+
+const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false)
+
 const tfptr = Ref{Py}()
 const tf2xlaptr = Ref{Py}()
 const npptr = Ref{Py}()
@@ -33,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict(
     ComplexF64 => :complex64,
 )

+const MLIR_TYPE_STRING = Dict(
+    Float64 => "fp64",
+    Float32 => "fp32",
+    Float16 => "fp16",
+    Int64 => "i64",
+    Int32 => "i32",
+    Int16 => "i16",
+    Int8 => "i8",
+    UInt64 => "ui64",
+    UInt32 => "ui32",
+    UInt16 => "ui16",
+    UInt8 => "ui8",
+    Bool => "i1",
+    Reactant.F8E4M3FN => "fp8e4nv",
+    Reactant.F8E5M2FNUZ => "fp8e5b16",
+    Reactant.F8E4M3FNUZ => "fp8e4b8",
+    Reactant.F8E5M2 => "fp8e5",
+)
+if isdefined(Core, :BFloat16)
+    MLIR_TYPE_STRING[Core.BFloat16] = "bf16"
+end
+
 function __init__()
     try
         jaxptr[] = pyimport("jax")
@@ -43,6 +71,14 @@ function __init__()
             be supported." exception = (err, catch_backtrace())
     end

+    try
+        tritonptr[] = pyimport("triton")
+        TRITON_COMPILE_SUPPORTED[] = true
+    catch err
+        @warn "Failed to import triton. Compiling jax functions with triton won't be \
+            supported." exception = (err, catch_backtrace())
+    end
+
     try
         tfptr[] = pyimport("tensorflow")
         tfptr[].config.set_visible_devices(pylist(); device_type="GPU")
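
Note on the new MLIR_TYPE_STRING table: it mirrors the type names Triton expects in kernel signature strings and is consumed by `signature_string` in pycall.jl below, where array arguments get a leading "*" to mark them as pointers. A minimal illustrative sketch (standalone, not the extension code; `triton_sig` is a made-up helper):

# Hypothetical sketch of the lookup performed by signature_string in pycall.jl.
const MLIR_TYPE_STRING = Dict(Float32 => "fp32", Float16 => "fp16", Int32 => "i32")

triton_sig(::Type{T}; pointer::Bool=false) where {T} =
    (pointer ? "*" : "") * MLIR_TYPE_STRING[T]

triton_sig(Float32; pointer=true)  # "*fp32" -- how a TracedRArray{Float32} is described
triton_sig(Int32)                  # "i32"   -- how a TracedRNumber{Int32} is described
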
Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
-@reactant_overlay function PythonCall.pycall(f::Py, args...)
+@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...)
     if Reactant.looped_any(Reactant.use_overlayed_version, args)
-        return pycall_with_jax_tracing(f, args...)
+        return overlayed_pycall(f, args...; kwargs...)
     else
-        return Base.inferencebarrier(PythonCall.pycall)(f, args...)
+        return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...)
     end
 end

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 153 additions & 1 deletion
@@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
     )
 end

-function pycall_with_jax_tracing(f::Py, args...)
+function overlayed_pycall(f::Py, args...; kwargs...)
+    @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
+    # TODO: check for Autotuner and Heuristics as well
+    if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
+        return overlayed_pycall_with_triton(f, args...; kwargs...)
+    else
+        @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions."
+        return overlayed_pycall_with_jax_tracing(f, args...)
+    end
+end
+
+function overlayed_pycall_with_jax_tracing(f::Py, args...)
     JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")

     seen_args = Reactant.OrderedIdDict()
@@ -35,3 +46,144 @@ function pycall_with_jax_tracing(f::Py, args...)
     res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
     return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
 end
+
+struct TritonMetadata{CK,MD,DP}
+    compiled_kernel::CK
+    metadata::MD
+    device_properties::DP
+    num_warps::Int
+    num_stages::Int
+    num_ctas::Int
+    num_regs::Int
+    num_spills::Int
+    max_num_threads::Int
+end
+
+canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata)
+canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata)
+function canonicalize_grid(grid::Dims{N}, metadata) where {N}
+    @assert N <= 3
+    @assert all(grid .> 0)
+    return (grid..., ntuple(_ -> 1, 3 - N)...)
+end
+
+signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
+signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
+signature_string(x::T) where {T<:Number} = string(x), x
+signature_string(x) = error("Unsupported argument type: $(typeof(x))")
+
+# TODO: better name for hints?
+function overlayed_pycall_with_triton(
+    kernel::Py,
+    args...;
+    grid,
+    num_warps::Integer=4,
+    num_stages::Integer=3,
+    num_ctas::Integer=1,
+    hints=nothing,
+)
+    @assert num_ctas == 1 "TODO: num_ctas > 1 not supported"
+    triton = tritonptr[]
+
+    mapped = map(signature_string, args)
+    signature = first.(mapped)
+    # TODO: are hints actually correctly set?
+    hints =
+        hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
+    constants = Dict(
+        kernel.arg_names[i - 1] => constant for
+        (i, constant) in enumerate(last.(mapped)) if constant !== nothing
+    )
+    for (k, v) in hints
+        v == 1 && (constants[kernel.arg_names[k - 1]] = v)
+    end
+    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)
+
+    sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
+    for k in keys(constants)
+        sigmap[k] = "constexpr"
+    end
+
+    for h in values(hints)
+        @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
+    end
+    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)
+
+    src = triton.compiler.ASTSource(;
+        fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
+    )
+
+    # TODO: pass the device/client here from `compile`
+    # TODO: cluster dims
+    client = Reactant.XLA.default_backend()
+    @assert Reactant.XLA.platform_name(client) == "cuda"
+    device = Reactant.XLA.default_device(client)
+    device_properties = Reactant.XLA.device_properties(device)
+
+    target = triton.backends.compiler.GPUTarget(
+        Reactant.XLA.platform_name(client),
+        parse(Int, "$(device_properties.major)$(device_properties.minor)"),
+        device_properties.warp_size,
+    )
+    backend = triton.compiler.make_backend(target)
+    options = backend.parse_options(
+        pydict(
+            "num_warps" => num_warps,
+            "num_stages" => num_stages,
+            "num_ctas" => num_ctas,
+            "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
+        ),
+    )
+
+    # Currently we are doing a double compilation here. can we do better?
+    # we are compiling here + lowering again inside enzymejax
+    compiled_kernel = triton.compile(src; target=target, options=options.__dict__)
+
+    cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"])
+    fname = pyconvert(String, compiled_kernel.metadata.name)
+    n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}()
+    GC.@preserve cubin fname n_regs n_spills n_max_threads begin
+        @ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
+            cubin::Ptr{Cvoid},
+            fname::Cstring,
+            n_regs::Ptr{Int32},
+            n_spills::Ptr{Int32},
+            n_max_threads::Ptr{Int32},
+        )::Cvoid
+    end
+
+    metadata = TritonMetadata(
+        compiled_kernel,
+        compiled_kernel.metadata,
+        device_properties,
+        num_warps,
+        num_stages,
+        num_ctas,
+        Int(n_regs[]),
+        Int(n_spills[]),
+        Int(n_max_threads[]),
+    )
+
+    grid = canonicalize_grid(grid, metadata)
+
+    # TODO: actual cluster_x/y/z
+
+    return @opcall triton_call(
+        pyconvert(String, compiled_kernel.asm["source"]),
+        filter(x -> x isa Reactant.TracedType, args)...;
+        func_name=fname,
+        grid_x=@opcall(constant(grid[1])),
+        grid_y=@opcall(constant(grid[2])),
+        grid_z=@opcall(constant(grid[3])),
+        block_x=@opcall(constant(num_warps * device_properties.warp_size)),
+        block_y=@opcall(constant(1)),
+        block_z=@opcall(constant(1)),
+        cluster_x=@opcall(constant(1)),
+        cluster_y=@opcall(constant(1)),
+        cluster_z=@opcall(constant(1)),
+        num_ctas,
+        num_warps,
+        threads_per_warp=device_properties.warp_size,
+        enable_source_remat=false,
+    )
+end
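
For context, a hedged end-to-end usage sketch of the path added above. Everything here is illustrative: the `mykernels.add_kernel` module, the sizes, and the output-handling convention are assumptions, and it requires a CUDA backend plus a working `triton` import. A `triton.jit` function called from traced code hits the `pycall` overlay, is detected via `pyisinstance(f, triton.JITFunction)`, compiled by `overlayed_pycall_with_triton`, and emitted as an `@opcall triton_call`:

using Reactant, PythonCall

# Assumed: a Python module `mykernels` on PYTHONPATH defining a @triton.jit kernel
#   add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr)
add_kernel = pyimport("mykernels").add_kernel

function vadd!(out, x, y)
    n = length(x)
    # Non-traced Julia numbers (n, 1024) become constexpr constants; traced arrays are
    # forwarded to triton_call. `grid` may be a tuple, an integer, or a function of the
    # TritonMetadata (see canonicalize_grid above).
    add_kernel(x, y, out, Int32(n), 1024; grid=md -> cld(n, 1024), num_warps=4)
    return out
end

x = Reactant.to_rarray(rand(Float32, 4096))
y = Reactant.to_rarray(rand(Float32, 4096))
out = Reactant.to_rarray(zeros(Float32, 4096))
vadd_compiled = @compile vadd!(out, x, y)
vadd_compiled(out, x, y)
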

src/CompileOptions.jl

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ function CompileOptions(;
             :canonicalize,
             :just_batch,
             :none,
+            :no_triton,
+            :before_triton_lowering,
         ]
     end

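
The two new symbols extend the list of values accepted by the pass-pipeline selector in CompileOptions. Presumably they are used like the existing `:none` / `:canonicalize` choices to stop before, or skip, the Triton-related lowering; a hedged sketch, assuming the `optimize` keyword of `@code_hlo` / `@compile` is the entry point and reusing the hypothetical `vadd!` from above:

# Assumption: dump the module just before Triton kernels are lowered, to inspect the
# generated triton_ext / tt.func IR.
hlo = @code_hlo optimize = :before_triton_lowering vadd!(out, x, y)

# Assumption: run the pipeline with the Triton-specific passes disabled.
vadd_no_triton = @compile optimize = :no_triton vadd!(out, x, y)
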
