Skip to content

Commit 827fbf4

Browse files
committed
feat: allow grid/blocks via a function [skip ci]
1 parent 57c4c2c commit 827fbf4

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
51-
function normalize_grid_and_blocks(grid::Dims{N}) where {N}
50+
function normalize_grid_and_blocks(grid_fn, metadata)
51+
return normalize_grid_and_blocks(grid_fn(metadata), metadata)
52+
end
53+
54+
function normalize_grid_and_blocks(grid::Integer, metadata)
55+
return normalize_grid_and_blocks((grid,), metadata)
56+
end
57+
function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N}
5258
@assert N <= 3
5359
@assert all(grid .> 0)
5460
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -71,9 +77,6 @@ function overlayed_pycall_with_triton(
7177
)
7278
triton = tritonptr[]
7379

74-
grid = normalize_grid_and_blocks(grid)
75-
blocks = normalize_grid_and_blocks(blocks)
76-
7780
mapped = map(signature_string, args)
7881
signature = first.(mapped)
7982
# TODO: are hints actually correctly set?
@@ -120,6 +123,9 @@ function overlayed_pycall_with_triton(
120123
# we are compiling here + lowering again inside enzymejax
121124
ccinfo = triton.compile(src; target=target, options=options.__dict__)
122125

126+
grid = normalize_grid_and_blocks(grid, ccinfo.metadata)
127+
blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata)
128+
123129
return @opcall triton_call(
124130
pyconvert(String, ccinfo.asm["source"]),
125131
filter(x -> x isa Reactant.TracedType, args)...;

0 commit comments

Comments
 (0)