@@ -47,8 +47,14 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747 return length (res) == 0 ? nothing : (length (res) == 1 ? res[1 ] : res)
4848end
4949
50- normalize_grid_and_blocks (grid:: Integer ) = normalize_grid_and_blocks ((grid,))
51- function normalize_grid_and_blocks (grid:: Dims{N} ) where {N}
50+ function normalize_grid_and_blocks (grid_fn, metadata)
51+ return normalize_grid_and_blocks (grid_fn (metadata), metadata)
52+ end
53+
54+ function normalize_grid_and_blocks (grid:: Integer , metadata)
55+ return normalize_grid_and_blocks ((grid,), metadata)
56+ end
57+ function normalize_grid_and_blocks (grid:: Dims{N} , metadata) where {N}
5258 @assert N <= 3
5359 @assert all (grid .> 0 )
5460 return (grid... , ntuple (_ -> 1 , 3 - N)... )
@@ -71,9 +77,6 @@ function overlayed_pycall_with_triton(
7177)
7278 triton = tritonptr[]
7379
74- grid = normalize_grid_and_blocks (grid)
75- blocks = normalize_grid_and_blocks (blocks)
76-
7780 mapped = map (signature_string, args)
7881 signature = first .(mapped)
7982 # TODO : are hints actually correctly set?
@@ -120,6 +123,9 @@ function overlayed_pycall_with_triton(
120123 # we are compiling here + lowering again inside enzymejax
121124 ccinfo = triton. compile (src; target= target, options= options. __dict__)
122125
126+ grid = normalize_grid_and_blocks (grid, ccinfo. metadata)
127+ blocks = normalize_grid_and_blocks (blocks, ccinfo. metadata)
128+
123129 return @opcall triton_call (
124130 pyconvert (String, ccinfo. asm[" source" ]),
125131 filter (x -> x isa Reactant. TracedType, args)... ;
0 commit comments