Skip to content

Commit 7ad8b4c

Browse files
committed
fix: new API
1 parent b4dc832 commit 7ad8b4c

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
# TODO: support using metaparams here
51-
normalize_grid(grid::Integer) = normalize_grid((grid,))
52-
function normalize_grid(grid::Dims{N}) where {N}
50+
normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
51+
function normalize_grid_and_blocks(grid::Dims{N}) where {N}
5352
@assert N <= 3
5453
@assert all(grid .> 0)
5554
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -62,11 +61,18 @@ signature_string(x) = error("Unsupported argument type: $(typeof(x))")
6261

6362
# TODO: better name for hints?
6463
function overlayed_pycall_with_triton(
65-
kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
64+
kernel::Py,
65+
args...;
66+
grid,
67+
blocks,
68+
num_warps::Integer=1,
69+
num_stages::Integer=3,
70+
hints=nothing,
6671
)
6772
triton = tritonptr[]
6873

69-
grid = normalize_grid(grid)
74+
grid = normalize_grid_and_blocks(grid)
75+
blocks = normalize_grid_and_blocks(blocks)
7076

7177
mapped = map(signature_string, args)
7278
signature = first.(mapped)
@@ -121,7 +127,9 @@ function overlayed_pycall_with_triton(
121127
grid_x=@opcall(constant(grid[1])),
122128
grid_y=@opcall(constant(grid[2])),
123129
grid_z=@opcall(constant(grid[3])),
124-
shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))),
130+
block_x=@opcall(constant(blocks[1])),
131+
block_y=@opcall(constant(blocks[2])),
132+
block_z=@opcall(constant(blocks[3])),
125133
)
126134

127135
return nothing

src/Compiler.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,7 @@ function compile_mlir!(
17911791
backend,
17921792
enable_triton_passes=false,
17931793
)
1794-
opt_passes3 = optimization_passes(
1794+
opt_passes_with_triton = optimization_passes(
17951795
compile_options;
17961796
sroa=false,
17971797
recognize_comms,
@@ -1851,12 +1851,12 @@ function compile_mlir!(
18511851
"enzyme-batch",
18521852
opt_passes2,
18531853
enzyme_pass,
1854-
opt_passes3,
1854+
opt_passes_with_triton,
18551855
"canonicalize",
18561856
"remove-unnecessary-enzyme-ops",
18571857
"enzyme-simplify-math",
18581858
legalize_chlo_to_stablehlo...,
1859-
opt_passes3,
1859+
opt_passes2,
18601860
lower_enzymexla_linalg_pass,
18611861
jit,
18621862
]
@@ -1867,12 +1867,12 @@ function compile_mlir!(
18671867
"enzyme-batch",
18681868
opt_passes2,
18691869
enzyme_pass,
1870-
opt_passes3,
1870+
opt_passes_with_triton,
18711871
"canonicalize",
18721872
"remove-unnecessary-enzyme-ops",
18731873
"enzyme-simplify-math",
18741874
legalize_chlo_to_stablehlo...,
1875-
opt_passes3,
1875+
opt_passes2,
18761876
kern,
18771877
raise_passes,
18781878
lower_enzymexla_linalg_pass,
@@ -1896,12 +1896,12 @@ function compile_mlir!(
18961896
"enzyme-batch",
18971897
opt_passes2,
18981898
enzyme_pass,
1899-
opt_passes3,
1899+
opt_passes_with_triton,
19001900
"canonicalize",
19011901
"remove-unnecessary-enzyme-ops",
19021902
"enzyme-simplify-math",
19031903
legalize_chlo_to_stablehlo...,
1904-
opt_passes3,
1904+
opt_passes2,
19051905
]
19061906
end,
19071907
',',
@@ -1921,12 +1921,12 @@ function compile_mlir!(
19211921
"enzyme-batch",
19221922
opt_passes2,
19231923
enzyme_pass,
1924-
opt_passes3,
1924+
opt_passes_with_triton,
19251925
"canonicalize",
19261926
"remove-unnecessary-enzyme-ops",
19271927
"enzyme-simplify-math",
19281928
legalize_chlo_to_stablehlo...,
1929-
opt_passes3,
1929+
opt_passes2,
19301930
]
19311931
else
19321932
[
@@ -1935,12 +1935,12 @@ function compile_mlir!(
19351935
"enzyme-batch",
19361936
opt_passes2,
19371937
enzyme_pass,
1938-
opt_passes3,
1938+
opt_passes_with_triton,
19391939
"canonicalize",
19401940
"remove-unnecessary-enzyme-ops",
19411941
"enzyme-simplify-math",
19421942
legalize_chlo_to_stablehlo...,
1943-
opt_passes3,
1943+
opt_passes2,
19441944
kern,
19451945
raise_passes,
19461946
]
@@ -1962,12 +1962,12 @@ function compile_mlir!(
19621962
"enzyme-batch",
19631963
opt_passes2,
19641964
enzyme_pass,
1965-
opt_passes3,
1965+
opt_passes_with_triton,
19661966
"canonicalize",
19671967
"remove-unnecessary-enzyme-ops",
19681968
"enzyme-simplify-math",
19691969
legalize_chlo_to_stablehlo...,
1970-
opt_passes3,
1970+
opt_passes2,
19711971
kern,
19721972
]
19731973
end,
@@ -1985,12 +1985,12 @@ function compile_mlir!(
19851985
"enzyme-batch",
19861986
opt_passes2,
19871987
enzyme_pass,
1988-
opt_passes3,
1988+
opt_passes_with_triton,
19891989
"canonicalize",
19901990
"remove-unnecessary-enzyme-ops",
19911991
"enzyme-simplify-math",
19921992
legalize_chlo_to_stablehlo...,
1993-
opt_passes3,
1993+
opt_passes2,
19941994
],
19951995
',',
19961996
),
@@ -2027,7 +2027,7 @@ function compile_mlir!(
20272027
"remove-unnecessary-enzyme-ops",
20282028
"enzyme-simplify-math",
20292029
legalize_chlo_to_stablehlo...,
2030-
opt_passes3,
2030+
opt_passes_with_triton,
20312031
lower_enzymexla_linalg_pass,
20322032
jit,
20332033
]
@@ -2040,7 +2040,7 @@ function compile_mlir!(
20402040
"remove-unnecessary-enzyme-ops",
20412041
"enzyme-simplify-math",
20422042
legalize_chlo_to_stablehlo...,
2043-
opt_passes3,
2043+
opt_passes_with_triton,
20442044
kern,
20452045
raise_passes,
20462046
lower_enzymexla_linalg_pass,
@@ -2251,7 +2251,7 @@ function compile_mlir!(
22512251
run_pass_pipeline!(
22522252
mod,
22532253
join(
2254-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
2254+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
22552255
",",
22562256
),
22572257
"mid_pad_opts",

src/Ops.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,7 +1847,9 @@ function triton_call(
18471847
grid_x::TracedRNumber{<:Integer},
18481848
grid_y::TracedRNumber{<:Integer},
18491849
grid_z::TracedRNumber{<:Integer},
1850-
shmem::TracedRNumber{<:Integer},
1850+
block_x::TracedRNumber{<:Integer},
1851+
block_y::TracedRNumber{<:Integer},
1852+
block_z::TracedRNumber{<:Integer},
18511853
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18521854
# TODO: other kwargs
18531855
)
@@ -1857,7 +1859,9 @@ function triton_call(
18571859
grid_x.mlir_data,
18581860
grid_y.mlir_data,
18591861
grid_z.mlir_data,
1860-
shmem.mlir_data,
1862+
block_x.mlir_data,
1863+
block_y.mlir_data,
1864+
block_z.mlir_data,
18611865
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18621866
fn=symref,
18631867
result_0=MLIR.IR.Type[],

src/mlir/Dialects/TritonExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ function call(
1717
gridx::Value,
1818
gridy::Value,
1919
gridz::Value,
20-
shmem::Value,
20+
blockx::Value,
21+
blocky::Value,
22+
blockz::Value,
2123
inputs::Vector{Value};
2224
result_0::Vector{IR.Type},
2325
fn,
@@ -31,7 +33,7 @@ function call(
3133
location=Location(),
3234
)
3335
op_ty_results = IR.Type[result_0...,]
34-
operands = Value[gridx, gridy, gridz, shmem, inputs...]
36+
operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...]
3537
owned_regions = Region[]
3638
successors = Block[]
3739
attributes = NamedAttribute[namedattribute("fn", fn),]

0 commit comments

Comments
 (0)