@@ -59,9 +59,9 @@ struct TritonMetadata{CK,MD,DP}
5959 max_num_threads:: Int
6060end
6161
62- normalize_grid (grid_fn, metadata) = normalize_grid (grid_fn (metadata), metadata)
63- normalize_grid (grid:: Integer , metadata) = normalize_grid ((grid,), metadata)
64- function normalize_grid (grid:: Dims{N} , metadata) where {N}
62+ canonicalize_grid (grid_fn, metadata) = canonicalize_grid (grid_fn (metadata), metadata)
63+ canonicalize_grid (grid:: Integer , metadata) = canonicalize_grid ((grid,), metadata)
64+ function canonicalize_grid (grid:: Dims{N} , metadata) where {N}
6565 @assert N <= 3
6666 @assert all (grid .> 0 )
6767 return (grid... , ntuple (_ -> 1 , 3 - N)... )
@@ -82,6 +82,7 @@ function overlayed_pycall_with_triton(
8282 num_ctas:: Integer = 1 ,
8383 hints= nothing ,
8484)
85+ @assert num_ctas == 1 " TODO: num_ctas > 1 not supported"
8586 triton = tritonptr[]
8687
8788 mapped = map (signature_string, args)
@@ -163,7 +164,7 @@ function overlayed_pycall_with_triton(
163164 Int (n_max_threads[]),
164165 )
165166
166- grid = normalize_grid (grid, metadata)
167+ grid = canonicalize_grid (grid, metadata)
167168
168169 return @opcall triton_call (
169170 pyconvert (String, compiled_kernel. asm[" source" ]),
@@ -177,5 +178,7 @@ function overlayed_pycall_with_triton(
177178 block_z= @opcall (constant (1 )),
178179 num_ctas,
179180 num_warps,
181+ threads_per_warp= device_properties. warp_size,
182+ enable_source_remat= false ,
180183 )
181184end
0 commit comments