Skip to content

Commit b84d519

Browse files
committed
fix: new API
1 parent 97c952b commit b84d519

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
# TODO: support using metaparams here
51-
normalize_grid(grid::Integer) = normalize_grid((grid,))
52-
function normalize_grid(grid::Dims{N}) where {N}
50+
normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
51+
function normalize_grid_and_blocks(grid::Dims{N}) where {N}
5352
@assert N <= 3
5453
@assert all(grid .> 0)
5554
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -62,11 +61,18 @@ signature_string(x) = error("Unsupported argument type: $(typeof(x))")
6261

6362
# TODO: better name for hints?
6463
function overlayed_pycall_with_triton(
65-
kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
64+
kernel::Py,
65+
args...;
66+
grid,
67+
blocks,
68+
num_warps::Integer=1,
69+
num_stages::Integer=3,
70+
hints=nothing,
6671
)
6772
triton = tritonptr[]
6873

69-
grid = normalize_grid(grid)
74+
grid = normalize_grid_and_blocks(grid)
75+
blocks = normalize_grid_and_blocks(blocks)
7076

7177
mapped = map(signature_string, args)
7278
signature = first.(mapped)
@@ -121,7 +127,9 @@ function overlayed_pycall_with_triton(
121127
grid_x=@opcall(constant(grid[1])),
122128
grid_y=@opcall(constant(grid[2])),
123129
grid_z=@opcall(constant(grid[3])),
124-
shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))),
130+
block_x=@opcall(constant(blocks[1])),
131+
block_y=@opcall(constant(blocks[2])),
132+
block_z=@opcall(constant(blocks[3])),
125133
)
126134

127135
return nothing

src/Compiler.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,7 +1796,7 @@ function compile_mlir!(
17961796
backend,
17971797
enable_triton_passes=false,
17981798
)
1799-
opt_passes3 = optimization_passes(
1799+
opt_passes_with_triton = optimization_passes(
18001800
compile_options;
18011801
sroa=false,
18021802
recognize_comms,
@@ -1856,12 +1856,12 @@ function compile_mlir!(
18561856
"enzyme-batch",
18571857
opt_passes2,
18581858
enzyme_pass,
1859-
opt_passes3,
1859+
opt_passes_with_triton,
18601860
"canonicalize",
18611861
"remove-unnecessary-enzyme-ops",
18621862
"enzyme-simplify-math",
18631863
legalize_chlo_to_stablehlo...,
1864-
opt_passes3,
1864+
opt_passes2,
18651865
lower_enzymexla_linalg_pass,
18661866
jit,
18671867
]
@@ -1872,12 +1872,12 @@ function compile_mlir!(
18721872
"enzyme-batch",
18731873
opt_passes2,
18741874
enzyme_pass,
1875-
opt_passes3,
1875+
opt_passes_with_triton,
18761876
"canonicalize",
18771877
"remove-unnecessary-enzyme-ops",
18781878
"enzyme-simplify-math",
18791879
legalize_chlo_to_stablehlo...,
1880-
opt_passes3,
1880+
opt_passes2,
18811881
kern,
18821882
raise_passes,
18831883
lower_enzymexla_linalg_pass,
@@ -1901,12 +1901,12 @@ function compile_mlir!(
19011901
"enzyme-batch",
19021902
opt_passes2,
19031903
enzyme_pass,
1904-
opt_passes3,
1904+
opt_passes_with_triton,
19051905
"canonicalize",
19061906
"remove-unnecessary-enzyme-ops",
19071907
"enzyme-simplify-math",
19081908
legalize_chlo_to_stablehlo...,
1909-
opt_passes3,
1909+
opt_passes2,
19101910
]
19111911
end,
19121912
',',
@@ -1926,12 +1926,12 @@ function compile_mlir!(
19261926
"enzyme-batch",
19271927
opt_passes2,
19281928
enzyme_pass,
1929-
opt_passes3,
1929+
opt_passes_with_triton,
19301930
"canonicalize",
19311931
"remove-unnecessary-enzyme-ops",
19321932
"enzyme-simplify-math",
19331933
legalize_chlo_to_stablehlo...,
1934-
opt_passes3,
1934+
opt_passes2,
19351935
]
19361936
else
19371937
[
@@ -1940,12 +1940,12 @@ function compile_mlir!(
19401940
"enzyme-batch",
19411941
opt_passes2,
19421942
enzyme_pass,
1943-
opt_passes3,
1943+
opt_passes_with_triton,
19441944
"canonicalize",
19451945
"remove-unnecessary-enzyme-ops",
19461946
"enzyme-simplify-math",
19471947
legalize_chlo_to_stablehlo...,
1948-
opt_passes3,
1948+
opt_passes2,
19491949
kern,
19501950
raise_passes,
19511951
]
@@ -1967,12 +1967,12 @@ function compile_mlir!(
19671967
"enzyme-batch",
19681968
opt_passes2,
19691969
enzyme_pass,
1970-
opt_passes3,
1970+
opt_passes_with_triton,
19711971
"canonicalize",
19721972
"remove-unnecessary-enzyme-ops",
19731973
"enzyme-simplify-math",
19741974
legalize_chlo_to_stablehlo...,
1975-
opt_passes3,
1975+
opt_passes2,
19761976
kern,
19771977
]
19781978
end,
@@ -1990,12 +1990,12 @@ function compile_mlir!(
19901990
"enzyme-batch",
19911991
opt_passes2,
19921992
enzyme_pass,
1993-
opt_passes3,
1993+
opt_passes_with_triton,
19941994
"canonicalize",
19951995
"remove-unnecessary-enzyme-ops",
19961996
"enzyme-simplify-math",
19971997
legalize_chlo_to_stablehlo...,
1998-
opt_passes3,
1998+
opt_passes2,
19991999
],
20002000
',',
20012001
),
@@ -2032,7 +2032,7 @@ function compile_mlir!(
20322032
"remove-unnecessary-enzyme-ops",
20332033
"enzyme-simplify-math",
20342034
legalize_chlo_to_stablehlo...,
2035-
opt_passes3,
2035+
opt_passes_with_triton,
20362036
lower_enzymexla_linalg_pass,
20372037
jit,
20382038
]
@@ -2045,7 +2045,7 @@ function compile_mlir!(
20452045
"remove-unnecessary-enzyme-ops",
20462046
"enzyme-simplify-math",
20472047
legalize_chlo_to_stablehlo...,
2048-
opt_passes3,
2048+
opt_passes_with_triton,
20492049
kern,
20502050
raise_passes,
20512051
lower_enzymexla_linalg_pass,
@@ -2256,7 +2256,7 @@ function compile_mlir!(
22562256
run_pass_pipeline!(
22572257
mod,
22582258
join(
2259-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
2259+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
22602260
",",
22612261
),
22622262
"mid_pad_opts",

src/Ops.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,7 +1847,9 @@ function triton_call(
18471847
grid_x::TracedRNumber{<:Integer},
18481848
grid_y::TracedRNumber{<:Integer},
18491849
grid_z::TracedRNumber{<:Integer},
1850-
shmem::TracedRNumber{<:Integer},
1850+
block_x::TracedRNumber{<:Integer},
1851+
block_y::TracedRNumber{<:Integer},
1852+
block_z::TracedRNumber{<:Integer},
18511853
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18521854
# TODO: other kwargs
18531855
)
@@ -1857,7 +1859,9 @@ function triton_call(
18571859
grid_x.mlir_data,
18581860
grid_y.mlir_data,
18591861
grid_z.mlir_data,
1860-
shmem.mlir_data,
1862+
block_x.mlir_data,
1863+
block_y.mlir_data,
1864+
block_z.mlir_data,
18611865
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18621866
fn=symref,
18631867
result_0=MLIR.IR.Type[],

src/mlir/Dialects/TritonExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ function call(
1717
gridx::Value,
1818
gridy::Value,
1919
gridz::Value,
20-
shmem::Value,
20+
blockx::Value,
21+
blocky::Value,
22+
blockz::Value,
2123
inputs::Vector{Value};
2224
result_0::Vector{IR.Type},
2325
fn,
@@ -31,7 +33,7 @@ function call(
3133
location=Location(),
3234
)
3335
op_ty_results = IR.Type[result_0...,]
34-
operands = Value[gridx, gridy, gridz, shmem, inputs...]
36+
operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...]
3537
owned_regions = Region[]
3638
successors = Block[]
3739
attributes = NamedAttribute[namedattribute("fn", fn),]

0 commit comments

Comments
 (0)