Skip to content

Commit 954f257

Browse files
committed
feat: correctly set strides + get n_regs
1 parent 735b3a9 commit 954f257

File tree

4 files changed

+100
-12
lines changed

4 files changed

+100
-12
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,26 @@ REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
816816
jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor;
817817
}
818818

819+
// Query per-kernel resource usage from a compiled GPU binary image.
//
// Loads `binary` as a CUDA module, looks up the kernel named `fnname`, and
// writes three values through the out-pointers:
//   *regs       - CU_FUNC_ATTRIBUTE_NUM_REGS
//   *spills     - CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES divided by 4
//                 (i.e. expressed in 4-byte words rather than bytes)
//   *maxThreads - CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
//
// Every driver call is routed through ReactantHandleCuResult. The module is
// only needed for the duration of the query, so it is unloaded before
// returning; otherwise each call would leave one more module loaded.
REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
    const char *binary, const char *fnname, int32_t *regs, int32_t *spills,
    int32_t *maxThreads) {
  CUmodule mod = nullptr;
  CUfunction fun = nullptr;

  ReactantHandleCuResult(cuModuleLoadData(&mod, binary));
  ReactantHandleCuResult(cuModuleGetFunction(&fun, mod, fnname));

  ReactantHandleCuResult(
      cuFuncGetAttribute(regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  ReactantHandleCuResult(
      cuFuncGetAttribute(spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  // Local memory size is reported in bytes; convert to 4-byte words.
  *spills /= 4;
  ReactantHandleCuResult(cuFuncGetAttribute(
      maxThreads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));

  // Release the module now that all attributes have been read. Guarded so a
  // failed load (handle still null) does not trigger a bogus unload.
  if (mod != nullptr) {
    ReactantHandleCuResult(cuModuleUnload(mod));
  }
}
838+
819839
#else
820840

821841
REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; }
@@ -831,6 +851,10 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; }
831851
REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
832852
int32_t device_id) {}
833853

854+
// Fallback definition for the `#else` branch of the surrounding conditional
// (presumably the build without CUDA; the `#if` condition is not visible
// here). No binary is inspected; the out-parameters are set to 0 so callers
// that read them back get a defined value, matching the sibling stubs in
// this branch that return 0.
REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
    const char *binary, const char *fnname, int32_t *regs, int32_t *spills,
    int32_t *maxThreads) {
  (void)binary;
  (void)fnname;
  if (regs != nullptr) {
    *regs = 0;
  }
  if (spills != nullptr) {
    *spills = 0;
  }
  if (maxThreads != nullptr) {
    *maxThreads = 0;
  }
}
857+
834858
#endif
835859

836860
REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,7 @@ cc_library(
984984
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
985985
"-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
986986
"-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties",
987+
"-Wl,-exported_symbol,_ReactantCudaGetRegsSpillsMaxThreadsFromBinary",
987988
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
988989
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
989990
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,25 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
function normalize_grid_and_blocks(grid_fn, metadata, device_properties)
51-
return normalize_grid_and_blocks(
52-
grid_fn(metadata, device_properties), metadata, device_properties
53-
)
50+
"""
    TritonMetadata{KernelT,MetaT,PropsT}

Bundle of everything known about a compiled Triton kernel. An instance is handed to
user-supplied grid/block callbacks (they are invoked as `grid_fn(metadata)`).

# Fields
- `compiled_kernel`: the object returned by `triton.compile`.
- `metadata`: the `metadata` attribute of `compiled_kernel`.
- `device_properties`: properties of the target device, as supplied by the caller.
- `num_warps`, `num_stages`, `num_ctas`: launch options the kernel was compiled with.
- `num_regs`, `num_spills`, `max_num_threads`: values reported by
  `ReactantCudaGetRegsSpillsMaxThreadsFromBinary` for the compiled binary.
"""
struct TritonMetadata{KernelT,MetaT,PropsT}
    compiled_kernel::KernelT
    metadata::MetaT
    device_properties::PropsT
    num_warps::Int
    num_stages::Int
    num_ctas::Int
    num_regs::Int
    num_spills::Int
    max_num_threads::Int
end
5561

56-
function normalize_grid_and_blocks(grid::Integer, metadata, device_properties)
57-
return normalize_grid_and_blocks((grid,), metadata, device_properties)
62+
# Fallback method: any `grid_fn` that is not matched by a more specific method is
# treated as a callable. It is invoked with `metadata`, and whatever it returns is
# normalized again by dispatching on the result.
function normalize_grid_and_blocks(grid_fn, metadata)
    resolved = grid_fn(metadata)
    return normalize_grid_and_blocks(resolved, metadata)
end
65+
# A bare integer is shorthand for a one-dimensional launch shape: wrap it in a
# 1-tuple and defer to the tuple method.
function normalize_grid_and_blocks(grid::Integer, metadata)
    as_tuple = (grid,)
    return normalize_grid_and_blocks(as_tuple, metadata)
end
59-
function normalize_grid_and_blocks(grid::Dims{N}, metadata, device_properties) where {N}
68+
function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N}
6069
@assert N <= 3
6170
@assert all(grid .> 0)
6271
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -131,15 +140,40 @@ function overlayed_pycall_with_triton(
131140

132141
# Currently we are doing a double compilation here. Can we do better?
133142
# we are compiling here + lowering again inside enzymejax
134-
ccinfo = triton.compile(src; target=target, options=options.__dict__)
143+
compiled_kernel = triton.compile(src; target=target, options=options.__dict__)
144+
145+
cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"])
146+
fname = pyconvert(String, compiled_kernel.metadata.name)
147+
n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}()
148+
GC.@preserve cubin fname n_regs n_spills n_max_threads begin
149+
@ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
150+
cubin::Ptr{Cvoid},
151+
fname::Cstring,
152+
n_regs::Ptr{Int32},
153+
n_spills::Ptr{Int32},
154+
n_max_threads::Ptr{Int32},
155+
)::Cvoid
156+
end
157+
158+
metadata = TritonMetadata(
159+
compiled_kernel,
160+
compiled_kernel.metadata,
161+
device_properties,
162+
num_warps,
163+
num_stages,
164+
num_ctas,
165+
Int(n_regs[]),
166+
Int(n_spills[]),
167+
Int(n_max_threads[]),
168+
)
135169

136-
grid = normalize_grid_and_blocks(grid, ccinfo.metadata, device_properties)
137-
blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata, device_properties)
170+
grid = normalize_grid_and_blocks(grid, metadata)
171+
blocks = normalize_grid_and_blocks(blocks, metadata)
138172

139173
return @opcall triton_call(
140-
pyconvert(String, ccinfo.asm["source"]),
174+
pyconvert(String, compiled_kernel.asm["source"]),
141175
filter(x -> x isa Reactant.TracedType, args)...;
142-
func_name=pyconvert(String, ccinfo.metadata.name),
176+
func_name=fname,
143177
grid_x=@opcall(constant(grid[1])),
144178
grid_y=@opcall(constant(grid[2])),
145179
grid_z=@opcall(constant(grid[3])),

src/Reactant.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,35 @@ include("stdlibs/Base.jl")
236236
# Other Integrations
237237
include("Enzyme.jl")
238238

239+
"""
    rowmajor_strides(x::AbstractArray)
    rowmajor_strides(sz::NTuple{N,Int})

Return the tuple of element strides that an array of size `size(x)` (or `sz`) would have
if it were laid out in row-major order: the last dimension has stride `1`, and the stride
of dimension `d` is the product of the sizes of all dimensions after `d`.

For example, `rowmajor_strides((2, 3, 4)) == (12, 4, 1)`.
"""
rowmajor_strides(x::AbstractArray) = rowmajor_strides(size(x))
function rowmajor_strides(sz::NTuple{N,Int}) where {N}
    return ntuple(Val(N)) do d
        # Stride of dimension `d` is the product of all trailing extents.
        acc = 1
        for k in (d + 1):N
            acc *= sz[k]
        end
        return acc
    end
end
252+
253+
"""
    rowmajor_stride(x::AbstractArray, i::Integer)
    rowmajor_stride(sz::NTuple{N,Int}, i::Integer)

Return the element stride of dimension `i` for an array of size `size(x)` (or `sz`),
assuming row-major storage. This is the product of the sizes of all dimensions after `i`,
so the last dimension has stride `1`.

For example, `rowmajor_stride((2, 3, 4), 1) == 12`.
"""
rowmajor_stride(x::AbstractArray, i::Integer) = rowmajor_stride(size(x), i)
function rowmajor_stride(sz::NTuple{N,Int}, i::Integer) where {N}
    # `init=1` covers the empty range that arises when `i` is the last dimension.
    return prod(sz[j] for j in (i + 1):N; init=1)
end
267+
239268
export StackedBatchDuplicated, StackedBatchDuplicatedNoNeed
240269

241270
const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

0 commit comments

Comments
 (0)