Skip to content

Commit bbe23a0

Browse files
committed
fix: new API
1 parent 747ab73 commit bbe23a0

File tree

5 files changed

+43
-29
lines changed

5 files changed

+43
-29
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "e408511ec376befe19bd48a0e725732b322fce3b"
7+
ENZYMEXLA_COMMIT = "42be7523965e87da6b8cec3ceec5e1d2002a594f"
88

99
ENZYMEXLA_SHA256 = ""
1010

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
# TODO: support using metaparams here
51-
normalize_grid(grid::Integer) = normalize_grid((grid,))
52-
function normalize_grid(grid::Dims{N}) where {N}
50+
normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
51+
function normalize_grid_and_blocks(grid::Dims{N}) where {N}
5352
@assert N <= 3
5453
@assert all(grid .> 0)
5554
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -62,11 +61,18 @@ signature_string(x) = error("Unsupported argument type: $(typeof(x))")
6261

6362
# TODO: better name for hints?
6463
function overlayed_pycall_with_triton(
65-
kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
64+
kernel::Py,
65+
args...;
66+
grid,
67+
blocks,
68+
num_warps::Integer=1,
69+
num_stages::Integer=3,
70+
hints=nothing,
6671
)
6772
triton = tritonptr[]
6873

69-
grid = normalize_grid(grid)
74+
grid = normalize_grid_and_blocks(grid)
75+
blocks = normalize_grid_and_blocks(blocks)
7076

7177
mapped = map(signature_string, args)
7278
signature = first.(mapped)
@@ -121,7 +127,9 @@ function overlayed_pycall_with_triton(
121127
grid_x=@opcall(constant(grid[1])),
122128
grid_y=@opcall(constant(grid[2])),
123129
grid_z=@opcall(constant(grid[3])),
124-
shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))),
130+
block_x=@opcall(constant(blocks[1])),
131+
block_y=@opcall(constant(blocks[2])),
132+
block_z=@opcall(constant(blocks[3])),
125133
)
126134

127135
return nothing

src/Compiler.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,7 +1781,7 @@ function compile_mlir!(
17811781
backend,
17821782
enable_triton_passes=false,
17831783
)
1784-
opt_passes3 = optimization_passes(
1784+
opt_passes_with_triton = optimization_passes(
17851785
compile_options;
17861786
sroa=false,
17871787
recognize_comms,
@@ -1841,12 +1841,12 @@ function compile_mlir!(
18411841
"enzyme-batch",
18421842
opt_passes2,
18431843
enzyme_pass,
1844-
opt_passes3,
1844+
opt_passes_with_triton,
18451845
"canonicalize",
18461846
"remove-unnecessary-enzyme-ops",
18471847
"enzyme-simplify-math",
18481848
legalize_chlo_to_stablehlo...,
1849-
opt_passes3,
1849+
opt_passes2,
18501850
lower_enzymexla_linalg_pass,
18511851
jit,
18521852
]
@@ -1857,12 +1857,12 @@ function compile_mlir!(
18571857
"enzyme-batch",
18581858
opt_passes2,
18591859
enzyme_pass,
1860-
opt_passes3,
1860+
opt_passes_with_triton,
18611861
"canonicalize",
18621862
"remove-unnecessary-enzyme-ops",
18631863
"enzyme-simplify-math",
18641864
legalize_chlo_to_stablehlo...,
1865-
opt_passes3,
1865+
opt_passes2,
18661866
kern,
18671867
raise_passes,
18681868
lower_enzymexla_linalg_pass,
@@ -1886,12 +1886,12 @@ function compile_mlir!(
18861886
"enzyme-batch",
18871887
opt_passes2,
18881888
enzyme_pass,
1889-
opt_passes3,
1889+
opt_passes_with_triton,
18901890
"canonicalize",
18911891
"remove-unnecessary-enzyme-ops",
18921892
"enzyme-simplify-math",
18931893
legalize_chlo_to_stablehlo...,
1894-
opt_passes3,
1894+
opt_passes2,
18951895
]
18961896
end,
18971897
',',
@@ -1911,12 +1911,12 @@ function compile_mlir!(
19111911
"enzyme-batch",
19121912
opt_passes2,
19131913
enzyme_pass,
1914-
opt_passes3,
1914+
opt_passes_with_triton,
19151915
"canonicalize",
19161916
"remove-unnecessary-enzyme-ops",
19171917
"enzyme-simplify-math",
19181918
legalize_chlo_to_stablehlo...,
1919-
opt_passes3,
1919+
opt_passes2,
19201920
]
19211921
else
19221922
[
@@ -1925,12 +1925,12 @@ function compile_mlir!(
19251925
"enzyme-batch",
19261926
opt_passes2,
19271927
enzyme_pass,
1928-
opt_passes3,
1928+
opt_passes_with_triton,
19291929
"canonicalize",
19301930
"remove-unnecessary-enzyme-ops",
19311931
"enzyme-simplify-math",
19321932
legalize_chlo_to_stablehlo...,
1933-
opt_passes3,
1933+
opt_passes2,
19341934
kern,
19351935
raise_passes,
19361936
]
@@ -1952,12 +1952,12 @@ function compile_mlir!(
19521952
"enzyme-batch",
19531953
opt_passes2,
19541954
enzyme_pass,
1955-
opt_passes3,
1955+
opt_passes_with_triton,
19561956
"canonicalize",
19571957
"remove-unnecessary-enzyme-ops",
19581958
"enzyme-simplify-math",
19591959
legalize_chlo_to_stablehlo...,
1960-
opt_passes3,
1960+
opt_passes2,
19611961
kern,
19621962
]
19631963
end,
@@ -1975,12 +1975,12 @@ function compile_mlir!(
19751975
"enzyme-batch",
19761976
opt_passes2,
19771977
enzyme_pass,
1978-
opt_passes3,
1978+
opt_passes_with_triton,
19791979
"canonicalize",
19801980
"remove-unnecessary-enzyme-ops",
19811981
"enzyme-simplify-math",
19821982
legalize_chlo_to_stablehlo...,
1983-
opt_passes3,
1983+
opt_passes2,
19841984
],
19851985
',',
19861986
),
@@ -2017,7 +2017,7 @@ function compile_mlir!(
20172017
"remove-unnecessary-enzyme-ops",
20182018
"enzyme-simplify-math",
20192019
legalize_chlo_to_stablehlo...,
2020-
opt_passes3,
2020+
opt_passes_with_triton,
20212021
lower_enzymexla_linalg_pass,
20222022
jit,
20232023
]
@@ -2030,7 +2030,7 @@ function compile_mlir!(
20302030
"remove-unnecessary-enzyme-ops",
20312031
"enzyme-simplify-math",
20322032
legalize_chlo_to_stablehlo...,
2033-
opt_passes3,
2033+
opt_passes_with_triton,
20342034
kern,
20352035
raise_passes,
20362036
lower_enzymexla_linalg_pass,
@@ -2241,7 +2241,7 @@ function compile_mlir!(
22412241
run_pass_pipeline!(
22422242
mod,
22432243
join(
2244-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
2244+
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
22452245
",",
22462246
),
22472247
"mid_pad_opts",

src/Ops.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,7 +1802,9 @@ function triton_call(
18021802
grid_x::TracedRNumber{<:Integer},
18031803
grid_y::TracedRNumber{<:Integer},
18041804
grid_z::TracedRNumber{<:Integer},
1805-
shmem::TracedRNumber{<:Integer},
1805+
block_x::TracedRNumber{<:Integer},
1806+
block_y::TracedRNumber{<:Integer},
1807+
block_z::TracedRNumber{<:Integer},
18061808
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18071809
# TODO: other kwargs
18081810
)
@@ -1812,7 +1814,9 @@ function triton_call(
18121814
grid_x.mlir_data,
18131815
grid_y.mlir_data,
18141816
grid_z.mlir_data,
1815-
shmem.mlir_data,
1817+
block_x.mlir_data,
1818+
block_y.mlir_data,
1819+
block_z.mlir_data,
18161820
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18171821
fn=symref,
18181822
result_0=MLIR.IR.Type[],

src/mlir/Dialects/TritonExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ function call(
1717
gridx::Value,
1818
gridy::Value,
1919
gridz::Value,
20-
shmem::Value,
20+
blockx::Value,
21+
blocky::Value,
22+
blockz::Value,
2123
inputs::Vector{Value};
2224
result_0::Vector{IR.Type},
2325
fn,
@@ -31,7 +33,7 @@ function call(
3133
location=Location(),
3234
)
3335
op_ty_results = IR.Type[result_0...,]
34-
operands = Value[gridx, gridy, gridz, shmem, inputs...]
36+
operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...]
3537
owned_regions = Region[]
3638
successors = Block[]
3739
attributes = NamedAttribute[namedattribute("fn", fn),]

0 commit comments

Comments
 (0)