@@ -7,12 +7,13 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
77 )
88end
99
"""
    overlayed_pycall(f::Py, args...; kwargs...)

Dispatch an overlayed Python call to the appropriate tracer: if `f` is a
Triton `JITFunction` (and Triton compilation is supported), compile it via
`overlayed_pycall_with_triton`; otherwise trace it through JAX via
`overlayed_pycall_with_jax_tracing`.

Keyword arguments are forwarded only on the Triton path. The JAX tracing
path does not support them and throws an `ArgumentError` when any are given.
"""
function overlayed_pycall(f::Py, args...; kwargs...)
    @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
    # TODO: check for Autotuner and Heuristics as well
    if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
        return overlayed_pycall_with_triton(f, args...; kwargs...)
    else
        # Input validation, not an internal invariant: fail loudly instead of
        # silently dropping user-supplied kwargs (and don't rely on `@assert`,
        # which may be disabled at higher optimization levels).
        isempty(kwargs) ||
            throw(ArgumentError("`kwargs` are not supported for jax traced functions."))
        return overlayed_pycall_with_jax_tracing(f, args...)
    end
end
@@ -46,6 +47,69 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4647 return length (res) == 0 ? nothing : (length (res) == 1 ? res[1 ] : res)
4748end
4849
49- function overlayed_pycall_with_triton (f:: Py , args... )
50- error (" TODO: implement triton" )
# TODO: support using metaparams here
"""
    normalize_grid(grid)

Normalize a Triton launch grid to a 3-tuple `(x, y, z)` of `Int`s. Accepts a
single integer (treated as a 1-D grid) or a tuple of 1 to 3 positive extents;
missing trailing dimensions are padded with `1`.

Throws an `ArgumentError` if more than 3 dimensions are given or any extent
is not positive.
"""
normalize_grid(grid::Integer) = normalize_grid((Int(grid),))  # Int(...) so any Integer subtype hits the Dims method
function normalize_grid(grid::Dims{N}) where {N}
    # Input validation belongs to the API boundary: use ArgumentError rather
    # than `@assert`, which may be compiled out.
    N <= 3 || throw(ArgumentError("grid may have at most 3 dimensions, got $N"))
    all(>(0), grid) ||
        throw(ArgumentError("all grid dimensions must be positive, got $grid"))
    return (grid..., ntuple(_ -> 1, 3 - N)...)
end
57+
# Map one kernel argument to `(signature_entry, constant)`: the Triton type
# string for the argument, plus the value to bake in as a compile-time
# constant (`nothing` when the argument is runtime data).
function signature_string(::TracedRArray{T}) where {T}
    return "*$(MLIR_TYPE_STRING[T])", nothing  # pointer to a buffer of T
end
function signature_string(::TracedRNumber{T}) where {T}
    return "$(MLIR_TYPE_STRING[T])", nothing   # scalar of T, passed at runtime
end
# Plain Julia numbers are specialized into the kernel as constants.
signature_string(x::Number) = (string(x), x)
signature_string(x) = error("Unsupported argument type: $(typeof(x))")
62+
"""
    overlayed_pycall_with_triton(kernel::Py, args...; grid, num_warps=1, num_stages=3, hints=nothing)

Compile a Triton `JITFunction` for the traced arguments.

# Keywords
- `grid`: launch grid — an integer or a tuple of up to 3 positive extents
  (see `normalize_grid`).
- `num_warps`, `num_stages`: forwarded to the Triton backend options.
- `hints`: optional iterable of `i => v` pairs mapping a 1-based argument
  index to a hint value; only `1` (treat as constant) and `16`
  (`tt.divisibility`) are supported.
"""
function overlayed_pycall_with_triton(
    kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
)
    triton = tritonptr[]

    grid = normalize_grid(grid)

    mapped = map(signature_string, args)
    signature = first.(mapped)

    # Re-key hints from 1-based positional indices to Python argument names
    # (kernel.arg_names is 0-indexed on the Python side, hence `i - 1`).
    hints =
        hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
    # Validate hint values BEFORE they are consumed below (the original
    # checked only after building `constants`/`attrs` from them).
    for v in values(hints)
        v in (1, 16) || throw(ArgumentError("Only 1 and 16 are valid hints, got $v"))
    end

    # Plain-number arguments become compile-time constants.
    constants = Dict(
        kernel.arg_names[i - 1] => constant for
        (i, constant) in enumerate(last.(mapped)) if constant !== nothing
    )
    for (k, v) in hints
        # `k` is already an argument name here (hints was re-keyed above);
        # indexing `kernel.arg_names` with it again was a bug.
        v == 1 && (constants[k] = v)
    end
    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)

    sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
    for k in keys(constants)
        sigmap[k] = "constexpr"
    end

    src = triton.compiler.ASTSource(;
        fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
    )

    # TODO: check that we are using CUDA. Get compute_capability from the target
    target = triton.backends.compiler.GPUTarget("cuda", 80, 32)
    backend = triton.compiler.make_backend(target)
    options = backend.parse_options(
        pydict(
            "num_warps" => num_warps,
            "num_stages" => num_stages,
            "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
        ),
    )

    ccinfo = triton.compile(src; target=target, options=options.__dict__)

    # Temporary debug output while kernel lowering is under development.
    println(pyconvert(String, ccinfo.asm["source"]))
    println(pyconvert(String, ccinfo.asm["ttir"]))

    return error("TODO: implement triton")
end
0 commit comments