Skip to content

Commit 90776d2

Browse files
committed
feat: initial triton setup [skip ci]
1 parent 2d0e0e3 commit 90776d2

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
jax = ">= 0.6"
33
tensorflow = ">= 2.17"
44
numpy = ">= 2"
5+
triton = "" # TODO: version bound

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ReactantPythonCallExt
22

3-
using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
3+
using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance
44
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
55
using Reactant.Ops: @opcall
66

@@ -9,6 +9,10 @@ const jnpptr = Ref{Py}()
99

1010
const JAX_TRACING_SUPPORTED = Ref{Bool}(false)
1111

12+
const tritonptr = Ref{Py}()
13+
14+
const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false)
15+
1216
const tfptr = Ref{Py}()
1317
const tf2xlaptr = Ref{Py}()
1418
const npptr = Ref{Py}()
@@ -43,6 +47,14 @@ function __init__()
4347
be supported." exception = (err, catch_backtrace())
4448
end
4549

50+
try
51+
tritonptr[] = pyimport("triton")
52+
TRITON_COMPILE_SUPPORTED[] = true
53+
catch err
54+
@warn "Failed to import triton. Compiling jax functions with triton won't be \
55+
supported." exception = (err, catch_backtrace())
56+
end
57+
4658
try
4759
tfptr[] = pyimport("tensorflow")
4860
tfptr[].config.set_visible_devices(pylist(); device_type="GPU")

ext/ReactantPythonCallExt/overlays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@reactant_overlay function PythonCall.pycall(f::Py, args...)
22
if Reactant.looped_any(Reactant.use_overlayed_version, args)
3-
return pycall_with_jax_tracing(f, args...)
3+
return overlayed_pycall(f, args...)
44
else
55
return Base.inferencebarrier(PythonCall.pycall)(f, args...)
66
end

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
77
)
88
end
99

10-
function pycall_with_jax_tracing(f::Py, args...)
10+
function overlayed_pycall(f::Py, args...)
11+
@assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
12+
# TODO: check for Autotuner and Heutistics as well
13+
if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
14+
return overlayed_pycall_with_triton(f, args...)
15+
else
16+
return overlayed_pycall_with_jax_tracing(f, args...)
17+
end
18+
end
19+
20+
function overlayed_pycall_with_jax_tracing(f::Py, args...)
1121
JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")
1222

1323
seen_args = Reactant.OrderedIdDict()
@@ -35,3 +45,7 @@ function pycall_with_jax_tracing(f::Py, args...)
3545
res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
3646
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
3747
end
48+
49+
function overlayed_pycall_with_triton(f::Py, args...)
50+
error("TODO: implement triton")
51+
end

0 commit comments

Comments
 (0)