@@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
     )
 end
 
-function pycall_with_jax_tracing(f::Py, args...)
+function overlayed_pycall(f::Py, args...; kwargs...)
+    @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
+    # TODO: check for Autotuner and Heuristics as well
+    if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
+        return overlayed_pycall_with_triton(f, args...; kwargs...)
+    else
+        @assert isempty(kwargs) "`kwargs` are not supported for JAX-traced functions."
+        return overlayed_pycall_with_jax_tracing(f, args...)
+    end
+end
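A hedged sketch of the two dispatch paths above, assuming tracing is active and `f` is a PythonCall `Py` object (`my_jax_fn` and `my_triton_kernel` are hypothetical names, not part of this commit):

```julia
# Triton JITFunctions accept launch kwargs; everything else goes through JAX tracing.
overlayed_pycall(my_jax_fn, x_traced)                  # JAX path, kwargs must be empty
overlayed_pycall(my_triton_kernel, x_traced; grid=16)  # Triton path, kwargs forwarded
```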
+
+function overlayed_pycall_with_jax_tracing(f::Py, args...)
     JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")
 
     seen_args = Reactant.OrderedIdDict()
@@ -35,3 +46,144 @@ function pycall_with_jax_tracing(f::Py, args...)
     res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
     return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
 end
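For context, an illustrative sketch of what the JAX-tracing path enables, assuming `jax` is importable through PythonCall (all names below are hypothetical, not part of this commit):

```julia
using PythonCall, Reactant

jnp = pyimport("jax.numpy")
f(x) = jnp.sin(x)  # Python call intercepted while Reactant traces `f`

x = Reactant.to_rarray(rand(Float32, 8))
@jit f(x)  # the jax function is lowered to StableHLO and spliced in via hlo_call
```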
+
+struct TritonMetadata{CK,MD,DP}
+    compiled_kernel::CK
+    metadata::MD
+    device_properties::DP
+    num_warps::Int
+    num_stages::Int
+    num_ctas::Int
+    num_regs::Int
+    num_spills::Int
+    max_num_threads::Int
+end
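The register/spill/thread statistics are read back from the compiled cubin further down and exposed through this struct; since a callable `grid` receives it (see `canonicalize_grid` below), a hypothetical occupancy-aware callback could look like:

```julia
# Hypothetical grid callback: one program instance per block of launched threads.
grid_fn = md -> cld(1_000_000, md.num_warps * md.device_properties.warp_size)
```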
+
+canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata)
+canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata)
+function canonicalize_grid(grid::Dims{N}, metadata) where {N}
+    @assert N <= 3
+    @assert all(grid .> 0)
+    return (grid..., ntuple(_ -> 1, 3 - N)...)
+end
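Illustrative normalizations (with `md` any metadata value), padding every grid to three dimensions:

```julia
canonicalize_grid(32, md)        # -> (32, 1, 1)
canonicalize_grid((4, 8), md)    # -> (4, 8, 1)
canonicalize_grid(md -> 16, md)  # -> (16, 1, 1), callable evaluated with the metadata
```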
+
+signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
+signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
+signature_string(x::T) where {T<:Number} = string(x), x
+signature_string(x) = error("Unsupported argument type: $(typeof(x))")
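A hedged sketch of the `(signature, constant)` pairs this yields, assuming `MLIR_TYPE_STRING[Float32] == "fp32"` (the mapping is defined elsewhere in the extension):

```julia
signature_string(traced_vec)  # -> ("*fp32", nothing)  traced array: pointer argument
signature_string(traced_num)  # -> ("fp32", nothing)   traced number: scalar argument
signature_string(128)         # -> ("128", 128)        plain number: baked-in constexpr
```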
+
+# TODO: better name for hints?
+function overlayed_pycall_with_triton(
+    kernel::Py,
+    args...;
+    grid,
+    num_warps::Integer=4,
+    num_stages::Integer=3,
+    num_ctas::Integer=1,
+    hints=nothing,
+)
+    @assert num_ctas == 1 "TODO: num_ctas > 1 not supported"
+    triton = tritonptr[]
+
+    mapped = map(signature_string, args)
+    signature = first.(mapped)
+    # TODO: are hints actually correctly set?
+    hints =
+        hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
+    constants = Dict(
+        kernel.arg_names[i - 1] => constant for
+        (i, constant) in enumerate(last.(mapped)) if constant !== nothing
+    )
+    for (k, v) in hints
+        # `hints` keys are already argument names at this point
+        v == 1 && (constants[k] = v)
+    end
+
+    sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
+    for k in keys(constants)
+        sigmap[k] = "constexpr"
+    end
+
+    for h in values(hints)
+        @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
+    end
+    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)
+
+    src = triton.compiler.ASTSource(;
+        fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
+    )
+
+    # TODO: pass the device/client here from `compile`
+    # TODO: cluster dims
+    client = Reactant.XLA.default_backend()
+    @assert Reactant.XLA.platform_name(client) == "cuda"
+    device = Reactant.XLA.default_device(client)
+    device_properties = Reactant.XLA.device_properties(device)
+
+    target = triton.backends.compiler.GPUTarget(
+        Reactant.XLA.platform_name(client),
+        # CUDA compute capability, e.g. major = 8, minor = 0 => 80
+        parse(Int, "$(device_properties.major)$(device_properties.minor)"),
+        device_properties.warp_size,
+    )
+    backend = triton.compiler.make_backend(target)
+    options = backend.parse_options(
+        pydict(
+            "num_warps" => num_warps,
+            "num_stages" => num_stages,
+            "num_ctas" => num_ctas,
+            "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
+        ),
+    )
+
+    # Currently we are doing a double compilation here. Can we do better?
+    # We are compiling here + lowering again inside Enzyme-JAX.
+    compiled_kernel = triton.compile(src; target=target, options=options.__dict__)
+
+    cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"])
+    fname = pyconvert(String, compiled_kernel.metadata.name)
+    # Read register usage, spill count, and max threads back from the cubin.
+    n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}()
+    GC.@preserve cubin fname n_regs n_spills n_max_threads begin
+        @ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
+            cubin::Ptr{Cvoid},
+            fname::Cstring,
+            n_regs::Ptr{Int32},
+            n_spills::Ptr{Int32},
+            n_max_threads::Ptr{Int32},
+        )::Cvoid
+    end
+
+    metadata = TritonMetadata(
+        compiled_kernel,
+        compiled_kernel.metadata,
+        device_properties,
+        num_warps,
+        num_stages,
+        num_ctas,
+        Int(n_regs[]),
+        Int(n_spills[]),
+        Int(n_max_threads[]),
+    )
+
+    grid = canonicalize_grid(grid, metadata)
+
+    # TODO: actual cluster_x/y/z
+
+    return @opcall triton_call(
+        pyconvert(String, compiled_kernel.asm["source"]),
+        filter(x -> x isa Reactant.TracedType, args)...;
+        func_name=fname,
+        grid_x=@opcall(constant(grid[1])),
+        grid_y=@opcall(constant(grid[2])),
+        grid_z=@opcall(constant(grid[3])),
+        block_x=@opcall(constant(num_warps * device_properties.warp_size)),
+        block_y=@opcall(constant(1)),
+        block_z=@opcall(constant(1)),
+        cluster_x=@opcall(constant(1)),
+        cluster_y=@opcall(constant(1)),
+        cluster_z=@opcall(constant(1)),
+        num_ctas,
+        num_warps,
+        threads_per_warp=device_properties.warp_size,
+        enable_source_remat=false,
+    )
+end
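For orientation, a hedged end-to-end sketch of how this entry point might be reached from Julia, assuming a Triton kernel defined through PythonCall (every name below is illustrative, not part of this commit):

```julia
using PythonCall, Reactant

g = pydict()
pyexec("""
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    a = tl.load(x_ptr + offs, mask=mask)
    b = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, a + b, mask=mask)
""", g)

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(rand(Float32, 1024))
out = Reactant.to_rarray(zeros(Float32, 1024))

# Plain numbers (1024, 256) become constexprs via signature_string; traced
# arrays are passed as pointers. During tracing the call is rerouted to
# overlayed_pycall_with_triton.
fn(x, y, out) = g["add_kernel"](x, y, out, 1024, 256; grid=cld(1024, 256))
@jit fn(x, y, out)
```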