Skip to content

Commit 9f1cb47

Browse files
committed
fix: new API
1 parent e869448 commit 9f1cb47

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
# TODO: support using metaparams here
51-
normalize_grid(grid::Integer) = normalize_grid((grid,))
52-
function normalize_grid(grid::Dims{N}) where {N}
50+
normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
51+
function normalize_grid_and_blocks(grid::Dims{N}) where {N}
5352
@assert N <= 3
5453
@assert all(grid .> 0)
5554
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -62,11 +61,18 @@ signature_string(x) = error("Unsupported argument type: $(typeof(x))")
6261

6362
# TODO: better name for hints?
6463
function overlayed_pycall_with_triton(
65-
kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
64+
kernel::Py,
65+
args...;
66+
grid,
67+
blocks,
68+
num_warps::Integer=1,
69+
num_stages::Integer=3,
70+
hints=nothing,
6671
)
6772
triton = tritonptr[]
6873

69-
grid = normalize_grid(grid)
74+
grid = normalize_grid_and_blocks(grid)
75+
blocks = normalize_grid_and_blocks(blocks)
7076

7177
mapped = map(signature_string, args)
7278
signature = first.(mapped)
@@ -121,7 +127,9 @@ function overlayed_pycall_with_triton(
121127
grid_x=@opcall(constant(grid[1])),
122128
grid_y=@opcall(constant(grid[2])),
123129
grid_z=@opcall(constant(grid[3])),
124-
shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))),
130+
block_x=@opcall(constant(blocks[1])),
131+
block_y=@opcall(constant(blocks[2])),
132+
block_z=@opcall(constant(blocks[3])),
125133
)
126134

127135
return nothing

src/Compiler.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,7 +1789,7 @@ function compile_mlir!(
17891789
backend,
17901790
enable_triton_passes=false,
17911791
)
1792-
opt_passes3 = optimization_passes(
1792+
opt_passes_with_triton = optimization_passes(
17931793
compile_options;
17941794
sroa=false,
17951795
recognize_comms,
@@ -1849,12 +1849,12 @@ function compile_mlir!(
18491849
"enzyme-batch",
18501850
opt_passes2,
18511851
enzyme_pass,
1852-
opt_passes3,
1852+
opt_passes_with_triton,
18531853
"canonicalize",
18541854
"remove-unnecessary-enzyme-ops",
18551855
"enzyme-simplify-math",
18561856
legalize_chlo_to_stablehlo...,
1857-
opt_passes3,
1857+
opt_passes2,
18581858
lower_enzymexla_linalg_pass,
18591859
jit,
18601860
]
@@ -1865,12 +1865,12 @@ function compile_mlir!(
18651865
"enzyme-batch",
18661866
opt_passes2,
18671867
enzyme_pass,
1868-
opt_passes3,
1868+
opt_passes_with_triton,
18691869
"canonicalize",
18701870
"remove-unnecessary-enzyme-ops",
18711871
"enzyme-simplify-math",
18721872
legalize_chlo_to_stablehlo...,
1873-
opt_passes3,
1873+
opt_passes2,
18741874
kern,
18751875
raise_passes,
18761876
lower_enzymexla_linalg_pass,
@@ -1894,12 +1894,12 @@ function compile_mlir!(
18941894
"enzyme-batch",
18951895
opt_passes2,
18961896
enzyme_pass,
1897-
opt_passes3,
1897+
opt_passes_with_triton,
18981898
"canonicalize",
18991899
"remove-unnecessary-enzyme-ops",
19001900
"enzyme-simplify-math",
19011901
legalize_chlo_to_stablehlo...,
1902-
opt_passes3,
1902+
opt_passes2,
19031903
]
19041904
end,
19051905
',',
@@ -1919,12 +1919,12 @@ function compile_mlir!(
19191919
"enzyme-batch",
19201920
opt_passes2,
19211921
enzyme_pass,
1922-
opt_passes3,
1922+
opt_passes_with_triton,
19231923
"canonicalize",
19241924
"remove-unnecessary-enzyme-ops",
19251925
"enzyme-simplify-math",
19261926
legalize_chlo_to_stablehlo...,
1927-
opt_passes3,
1927+
opt_passes2,
19281928
]
19291929
else
19301930
[
@@ -1933,12 +1933,12 @@ function compile_mlir!(
19331933
"enzyme-batch",
19341934
opt_passes2,
19351935
enzyme_pass,
1936-
opt_passes3,
1936+
opt_passes_with_triton,
19371937
"canonicalize",
19381938
"remove-unnecessary-enzyme-ops",
19391939
"enzyme-simplify-math",
19401940
legalize_chlo_to_stablehlo...,
1941-
opt_passes3,
1941+
opt_passes2,
19421942
kern,
19431943
raise_passes,
19441944
]
@@ -1960,12 +1960,12 @@ function compile_mlir!(
19601960
"enzyme-batch",
19611961
opt_passes2,
19621962
enzyme_pass,
1963-
opt_passes3,
1963+
opt_passes_with_triton,
19641964
"canonicalize",
19651965
"remove-unnecessary-enzyme-ops",
19661966
"enzyme-simplify-math",
19671967
legalize_chlo_to_stablehlo...,
1968-
opt_passes3,
1968+
opt_passes2,
19691969
kern,
19701970
]
19711971
end,
@@ -1983,12 +1983,12 @@ function compile_mlir!(
19831983
"enzyme-batch",
19841984
opt_passes2,
19851985
enzyme_pass,
1986-
opt_passes3,
1986+
opt_passes_with_triton,
19871987
"canonicalize",
19881988
"remove-unnecessary-enzyme-ops",
19891989
"enzyme-simplify-math",
19901990
legalize_chlo_to_stablehlo...,
1991-
opt_passes3,
1991+
opt_passes2,
19921992
],
19931993
',',
19941994
),
@@ -2025,7 +2025,7 @@ function compile_mlir!(
20252025
"remove-unnecessary-enzyme-ops",
20262026
"enzyme-simplify-math",
20272027
legalize_chlo_to_stablehlo...,
2028-
opt_passes3,
2028+
opt_passes_with_triton,
20292029
lower_enzymexla_linalg_pass,
20302030
jit,
20312031
]
@@ -2038,7 +2038,7 @@ function compile_mlir!(
20382038
"remove-unnecessary-enzyme-ops",
20392039
"enzyme-simplify-math",
20402040
legalize_chlo_to_stablehlo...,
2041-
opt_passes3,
2041+
opt_passes_with_triton,
20422042
kern,
20432043
raise_passes,
20442044
lower_enzymexla_linalg_pass,
@@ -2249,7 +2249,7 @@ function compile_mlir!(
22492249
run_pass_pipeline!(
22502250
mod,
22512251
join(
2252-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
2252+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
22532253
",",
22542254
),
22552255
"mid_pad_opts",

src/Ops.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,7 +1847,9 @@ function triton_call(
18471847
grid_x::TracedRNumber{<:Integer},
18481848
grid_y::TracedRNumber{<:Integer},
18491849
grid_z::TracedRNumber{<:Integer},
1850-
shmem::TracedRNumber{<:Integer},
1850+
block_x::TracedRNumber{<:Integer},
1851+
block_y::TracedRNumber{<:Integer},
1852+
block_z::TracedRNumber{<:Integer},
18511853
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18521854
# TODO: other kwargs
18531855
)
@@ -1857,7 +1859,9 @@ function triton_call(
18571859
grid_x.mlir_data,
18581860
grid_y.mlir_data,
18591861
grid_z.mlir_data,
1860-
shmem.mlir_data,
1862+
block_x.mlir_data,
1863+
block_y.mlir_data,
1864+
block_z.mlir_data,
18611865
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18621866
fn=symref,
18631867
result_0=MLIR.IR.Type[],

src/mlir/Dialects/TritonExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ function call(
1717
gridx::Value,
1818
gridy::Value,
1919
gridz::Value,
20-
shmem::Value,
20+
blockx::Value,
21+
blocky::Value,
22+
blockz::Value,
2123
inputs::Vector{Value};
2224
result_0::Vector{IR.Type},
2325
fn,
@@ -31,7 +33,7 @@ function call(
3133
location=Location(),
3234
)
3335
op_ty_results = IR.Type[result_0...,]
34-
operands = Value[gridx, gridy, gridz, shmem, inputs...]
36+
operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...]
3537
owned_regions = Region[]
3638
successors = Block[]
3739
attributes = NamedAttribute[namedattribute("fn", fn),]

0 commit comments

Comments
 (0)