1 change: 1 addition & 0 deletions CondaPkg.toml
@@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4"
jax = ">= 0.6"
tensorflow = ">= 2.17"
numpy = ">= 2"
triton = ">= 3.4"
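The new `triton` pin is resolved through CondaPkg alongside the existing `jax` and `tensorflow` entries. A minimal sketch of adding the same dependency to another CondaPkg-managed project (illustrative, not part of this diff):

```julia
using CondaPkg

# Mirrors the pin above; CondaPkg records it in CondaPkg.toml and resolves
# the conda environment on the next `using PythonCall`.
CondaPkg.add("triton"; version=">= 3.4")
```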
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "47d57c1cea7b24e210ad75aee6e7c3f93d89ff78"
ENZYMEXLA_COMMIT = "defe9ed6f939cc22a7715f2b8c98a39d9e51e2c9"

ENZYMEXLA_SHA256 = ""

38 changes: 37 additions & 1 deletion ext/ReactantPythonCallExt/ReactantPythonCallExt.jl
@@ -1,14 +1,20 @@
module ReactantPythonCallExt

-using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
+using PythonCall:
+    PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
using Reactant.Ops: @opcall
using Reactant_jll: Reactant_jll

const jaxptr = Ref{Py}()
const jnpptr = Ref{Py}()

const JAX_TRACING_SUPPORTED = Ref{Bool}(false)

const tritonptr = Ref{Py}()

const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false)

const tfptr = Ref{Py}()
const tf2xlaptr = Ref{Py}()
const npptr = Ref{Py}()
@@ -33,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict(
ComplexF64 => :complex64,
)

const MLIR_TYPE_STRING = Dict(
Float64 => "fp64",
Float32 => "fp32",
Float16 => "fp16",
Int64 => "i64",
Int32 => "i32",
Int16 => "i16",
Int8 => "i8",
UInt64 => "ui64",
UInt32 => "ui32",
UInt16 => "ui16",
UInt8 => "ui8",
Bool => "i1",
Reactant.F8E4M3FN => "fp8e4nv",
Reactant.F8E5M2FNUZ => "fp8e5b16",
Reactant.F8E4M3FNUZ => "fp8e4b8",
Reactant.F8E5M2 => "fp8e5",
)
if isdefined(Core, :BFloat16)
MLIR_TYPE_STRING[Core.BFloat16] = "bf16"
end

function __init__()
try
jaxptr[] = pyimport("jax")
@@ -43,6 +71,14 @@ function __init__()
be supported." exception = (err, catch_backtrace())
end

try
tritonptr[] = pyimport("triton")
TRITON_COMPILE_SUPPORTED[] = true
catch err
@warn "Failed to import triton. Compiling jax functions with triton won't be \
supported." exception = (err, catch_backtrace())
end

try
tfptr[] = pyimport("tensorflow")
tfptr[].config.set_visible_devices(pylist(); device_type="GPU")
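`MLIR_TYPE_STRING` maps Julia element types to Triton's type-string spellings; `signature_string` in `pycall.jl` below consumes it when building kernel signatures. An illustrative sketch (values read off the table above):

```julia
MLIR_TYPE_STRING[Float32]            # "fp32"; a TracedRArray{Float32} signs as "*fp32"
MLIR_TYPE_STRING[Reactant.F8E4M3FN]  # "fp8e4nv", Triton's spelling of E4M3
MLIR_TYPE_STRING[Bool]               # "i1"
```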
6 changes: 3 additions & 3 deletions ext/ReactantPythonCallExt/overlays.jl
@@ -1,7 +1,7 @@
-@reactant_overlay function PythonCall.pycall(f::Py, args...)
+@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...)
     if Reactant.looped_any(Reactant.use_overlayed_version, args)
-        return pycall_with_jax_tracing(f, args...)
+        return overlayed_pycall(f, args...; kwargs...)
     else
-        return Base.inferencebarrier(PythonCall.pycall)(f, args...)
+        return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...)
     end
end
end
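With the overlay in place, `pycall` on traced arguments is rerouted to `overlayed_pycall`, while ordinary calls fall through to `PythonCall.pycall`. A hedged usage sketch (assumes jax is importable; `g` and `x` are illustrative names, not part of this diff):

```julia
using Reactant, PythonCall

jnp = pyimport("jax.numpy")
g(x) = jnp.sin(x)                  # a Julia function that calls into Python

x = Reactant.to_rarray(rand(Float32, 16))
compiled = Reactant.@compile g(x)  # traced args: pycall takes the overlay path
compiled(x)

g(rand(Float32, 16))               # untraced args: plain PythonCall.pycall
```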
154 changes: 153 additions & 1 deletion ext/ReactantPythonCallExt/pycall.jl
@@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
)
end

-function pycall_with_jax_tracing(f::Py, args...)
+function overlayed_pycall(f::Py, args...; kwargs...)
@assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
    # TODO: check for Autotuner and Heuristics as well
if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
return overlayed_pycall_with_triton(f, args...; kwargs...)
else
@assert isempty(kwargs) "`kwargs` are not supported for jax traced functions."
return overlayed_pycall_with_jax_tracing(f, args...)
end
end

function overlayed_pycall_with_jax_tracing(f::Py, args...)
JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")

seen_args = Reactant.OrderedIdDict()
@@ -35,3 +46,144 @@ function pycall_with_jax_tracing(f::Py, args...)
res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
end

struct TritonMetadata{CK,MD,DP}
compiled_kernel::CK
metadata::MD
device_properties::DP
num_warps::Int
num_stages::Int
num_ctas::Int
num_regs::Int
num_spills::Int
max_num_threads::Int
end

canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata)
canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata)
function canonicalize_grid(grid::Dims{N}, metadata) where {N}
@assert N <= 3
@assert all(grid .> 0)
return (grid..., ntuple(_ -> 1, 3 - N)...)
end
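# Illustrative, not part of this diff: all three grid spellings normalize to 3-D,
#   canonicalize_grid(4, md)       -> (4, 1, 1)
#   canonicalize_grid((8, 8), md)  -> (8, 8, 1)
#   canonicalize_grid(md -> 4, md) -> (4, 1, 1)   (callable form)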

signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
signature_string(x::T) where {T<:Number} = string(x), x
signature_string(x) = error("Unsupported argument type: $(typeof(x))")
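# Illustrative, not part of this diff: the second tuple element marks values
# that are baked in as constants,
#   TracedRArray{Float32} -> ("*fp32", nothing)   (device pointer)
#   TracedRNumber{Int32}  -> ("i32", nothing)
#   1024::Int             -> ("1024", 1024)       (specialized constant)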

# TODO: better name for hints?
function overlayed_pycall_with_triton(
kernel::Py,
args...;
grid,
num_warps::Integer=4,
num_stages::Integer=3,
num_ctas::Integer=1,
hints=nothing,
)
@assert num_ctas == 1 "TODO: num_ctas > 1 not supported"
triton = tritonptr[]

mapped = map(signature_string, args)
signature = first.(mapped)
# TODO: are hints actually correctly set?
hints =
hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
constants = Dict(
kernel.arg_names[i - 1] => constant for
(i, constant) in enumerate(last.(mapped)) if constant !== nothing
)
for (k, v) in hints
        v == 1 && (constants[k] = v)  # hints are already keyed by arg name here
end

sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
for k in keys(constants)
sigmap[k] = "constexpr"
end

for h in values(hints)
@assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
end
attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)

src = triton.compiler.ASTSource(;
fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
)

# TODO: pass the device/client here from `compile`
# TODO: cluster dims
client = Reactant.XLA.default_backend()
@assert Reactant.XLA.platform_name(client) == "cuda"
device = Reactant.XLA.default_device(client)
device_properties = Reactant.XLA.device_properties(device)

target = triton.backends.compiler.GPUTarget(
Reactant.XLA.platform_name(client),
parse(Int, "$(device_properties.major)$(device_properties.minor)"),
device_properties.warp_size,
)
backend = triton.compiler.make_backend(target)
options = backend.parse_options(
pydict(
"num_warps" => num_warps,
"num_stages" => num_stages,
"num_ctas" => num_ctas,
"extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
),
)

    # Currently we are doing a double compilation here. Can we do better?
    # We are compiling here and then lowering again inside Enzyme-JAX.
compiled_kernel = triton.compile(src; target=target, options=options.__dict__)

cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"])
fname = pyconvert(String, compiled_kernel.metadata.name)
n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}()
GC.@preserve cubin fname n_regs n_spills n_max_threads begin
@ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
cubin::Ptr{Cvoid},
fname::Cstring,
n_regs::Ptr{Int32},
n_spills::Ptr{Int32},
n_max_threads::Ptr{Int32},
)::Cvoid
end

metadata = TritonMetadata(
compiled_kernel,
compiled_kernel.metadata,
device_properties,
num_warps,
num_stages,
num_ctas,
Int(n_regs[]),
Int(n_spills[]),
Int(n_max_threads[]),
)

grid = canonicalize_grid(grid, metadata)

# TODO: actual cluster_x/y/z

return @opcall triton_call(
pyconvert(String, compiled_kernel.asm["source"]),
filter(x -> x isa Reactant.TracedType, args)...;
func_name=fname,
grid_x=@opcall(constant(grid[1])),
grid_y=@opcall(constant(grid[2])),
grid_z=@opcall(constant(grid[3])),
block_x=@opcall(constant(num_warps * device_properties.warp_size)),
block_y=@opcall(constant(1)),
block_z=@opcall(constant(1)),
cluster_x=@opcall(constant(1)),
cluster_y=@opcall(constant(1)),
cluster_z=@opcall(constant(1)),
num_ctas,
num_warps,
threads_per_warp=device_properties.warp_size,
enable_source_remat=false,
)
end
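End to end, a `triton.jit` function called with traced arrays now compiles through this path. A hedged sketch (assumes a CUDA backend; the kernel body and all names below are illustrative, not part of this diff):

```julia
using Reactant, PythonCall

ns = pydict()
pyexec("""
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
""", ns)
add_kernel = ns["add_kernel"]

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(rand(Float32, 1024))
out = Reactant.to_rarray(zeros(Float32, 1024))

# Literal scalars (1024, 256) become constants via signature_string; `grid`
# may be an Integer, a Dims tuple, or a callable on the launch metadata.
vecadd!(x, y, out) = add_kernel(x, y, out, 1024, 256; grid=cld(1024, 256))
compiled = Reactant.@compile vecadd!(x, y, out)
compiled(x, y, out)
```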
2 changes: 2 additions & 0 deletions src/CompileOptions.jl
@@ -229,6 +229,8 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:no_triton,
:before_triton_lowering,
]
end
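The two new symbols let a compilation stop around Triton lowering, e.g. to inspect the module before the Triton passes run. A hedged sketch (the `optimization_passes` and `compile_options` keyword spellings are assumptions; the diff only shows the validated symbol list):

```julia
# Assumed kwarg names, not confirmed by this diff.
opts = Reactant.CompileOptions(; optimization_passes=:before_triton_lowering)
compiled = Reactant.@compile compile_options = opts vecadd!(x, y, out)
```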
