@@ -47,16 +47,25 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747 return length (res) == 0 ? nothing : (length (res) == 1 ? res[1 ] : res)
4848end
4949
50- function normalize_grid_and_blocks (grid_fn, metadata, device_properties)
51- return normalize_grid_and_blocks (
52- grid_fn (metadata, device_properties), metadata, device_properties
53- )
"""
    TritonMetadata{CK,MD,DP}

Information about a compiled Triton kernel, bundled into one value. It is the `metadata`
argument handed to `normalize_grid_and_blocks`, and therefore also what a user-supplied
grid/block callable receives.

The three type parameters keep the Python-side / device-side objects concretely typed
without constraining what they are.
"""
struct TritonMetadata{CK,MD,DP}
    # Object returned by `triton.compile`.
    compiled_kernel::CK
    # The `metadata` attribute of `compiled_kernel`.
    metadata::MD
    # Device properties forwarded by the caller.
    device_properties::DP
    # NOTE(review): presumably the Triton compile options of the same names — confirm
    # against the caller, their origin is not visible from this definition.
    num_warps::Int
    num_stages::Int
    num_ctas::Int
    # Register count, spill count and maximum thread count reported for the kernel
    # binary (filled from `ReactantCudaGetRegsSpillsMaxThreadsFromBinary` at the
    # construction site).
    num_regs::Int
    num_spills::Int
    max_num_threads::Int
end
5561
56- function normalize_grid_and_blocks (grid:: Integer , metadata, device_properties)
57- return normalize_grid_and_blocks ((grid,), metadata, device_properties)
"""
    normalize_grid_and_blocks(grid_fn, metadata)

Fallback for a grid/block specification given as a callable: invoke `grid_fn(metadata)`
and normalize whatever it returns by dispatching on the result again.

Because `grid_fn` is untyped, any specification that no more specific method accepts
also ends up here and will be called like a function.
"""
function normalize_grid_and_blocks(grid_fn, metadata)
    computed = grid_fn(metadata)
    return normalize_grid_and_blocks(computed, metadata)
end
"""
    normalize_grid_and_blocks(grid::Integer, metadata)

Treat a single integer as a one-dimensional grid/block specification by wrapping it in
a 1-tuple and normalizing that tuple.

The value is converted to `Int` first so the tuple is a `Dims{1}`. Without the
conversion, a non-`Int` integer such as `Int32(4)` would yield a `Tuple{Int32}`, which
does not match a `Dims`-typed method and would instead be picked up by the untyped
callable fallback and fail when "called". Throws `InexactError` if `grid` does not fit
in an `Int`.
"""
function normalize_grid_and_blocks(grid::Integer, metadata)
    return normalize_grid_and_blocks((Int(grid),), metadata)
end
59- function normalize_grid_and_blocks (grid:: Dims{N} , metadata, device_properties ) where {N}
68+ function normalize_grid_and_blocks (grid:: Dims{N} , metadata) where {N}
6069 @assert N <= 3
6170 @assert all (grid .> 0 )
6271 return (grid... , ntuple (_ -> 1 , 3 - N)... )
@@ -131,15 +140,40 @@ function overlayed_pycall_with_triton(
131140
132141 # Currently we compile twice: once here, and then the kernel is lowered again
133142 # inside enzymejax. Can we do better?
134- ccinfo = triton. compile (src; target= target, options= options. __dict__)
143+ compiled_kernel = triton. compile (src; target= target, options= options. __dict__)
144+
145+ cubin = pyconvert (Vector{UInt8}, compiled_kernel. asm[" cubin" ])
146+ fname = pyconvert (String, compiled_kernel. metadata. name)
147+ n_regs, n_spills, n_max_threads = Ref {Int32} (), Ref {Int32} (), Ref {Int32} ()
148+ GC. @preserve cubin fname n_regs n_spills n_max_threads begin
149+ @ccall Reactant. MLIR. API. mlir_c. ReactantCudaGetRegsSpillsMaxThreadsFromBinary (
150+ cubin:: Ptr{Cvoid} ,
151+ fname:: Cstring ,
152+ n_regs:: Ptr{Int32} ,
153+ n_spills:: Ptr{Int32} ,
154+ n_max_threads:: Ptr{Int32} ,
155+ ):: Cvoid
156+ end
157+
158+ metadata = TritonMetadata (
159+ compiled_kernel,
160+ compiled_kernel. metadata,
161+ device_properties,
162+ num_warps,
163+ num_stages,
164+ num_ctas,
165+ Int (n_regs[]),
166+ Int (n_spills[]),
167+ Int (n_max_threads[]),
168+ )
135169
136- grid = normalize_grid_and_blocks (grid, ccinfo . metadata, device_properties )
137- blocks = normalize_grid_and_blocks (blocks, ccinfo . metadata, device_properties )
170+ grid = normalize_grid_and_blocks (grid, metadata)
171+ blocks = normalize_grid_and_blocks (blocks, metadata)
138172
139173 return @opcall triton_call (
140- pyconvert (String, ccinfo . asm[" source" ]),
174+ pyconvert (String, compiled_kernel . asm[" source" ]),
141175 filter (x -> x isa Reactant. TracedType, args)... ;
142- func_name= pyconvert (String, ccinfo . metadata . name) ,
176+ func_name= fname ,
143177 grid_x= @opcall (constant (grid[1 ])),
144178 grid_y= @opcall (constant (grid[2 ])),
145179 grid_z= @opcall (constant (grid[3 ])),
0 commit comments