Skip to content

Commit 8ca7324

Browse files
committed
feat: tracing fully functional
1 parent ef82c3d commit 8ca7324

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,19 @@ function overlayed_pycall_with_triton(
106106
),
107107
)
108108

109+
# Currently we are doing a double compilation here. can we do better?
110+
# we are compiling here + lowering again inside enzymejax
109111
ccinfo = triton.compile(src; target=target, options=options.__dict__)
110112

111-
@show ccinfo.metadata
112-
@show ccinfo.asm.keys()
113-
# shared = ccinfo.metadata["shared"]
114-
kernel_name = pyconvert(String, ccinfo.metadata.name)
115-
# cluster_dims = ccinfo.metadata["cluster_dims"]
116-
117-
# println(pyconvert(String, ccinfo.asm["source"]))
118-
# println(pyconvert(String, ccinfo.asm["ttir"]))
119-
120-
res = @opcall triton_call(
121-
pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name
113+
@opcall triton_call(
114+
pyconvert(String, ccinfo.asm["ttir"]),
115+
filter(x -> x isa Reactant.TracedType, args)...;
116+
func_name=pyconvert(String, ccinfo.metadata.name),
117+
grid_x=@opcall(constant(grid[1])),
118+
grid_y=@opcall(constant(grid[2])),
119+
grid_z=@opcall(constant(grid[3])),
120+
shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))),
122121
)
123122

124-
return error("TODO: implement triton")
123+
return nothing
125124
end

src/Ops.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,14 +1763,27 @@ function triton_call(
17631763
mlir_code::String,
17641764
args::Union{TracedRArray,TracedRNumber,Number}...;
17651765
func_name::String="main",
1766+
grid_x::TracedRNumber{<:Integer},
1767+
grid_y::TracedRNumber{<:Integer},
1768+
grid_z::TracedRNumber{<:Integer},
1769+
shmem::TracedRNumber{<:Integer},
17661770
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
1771+
# TODO: other kwargs
17671772
)
17681773
name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
17691774

1770-
@show name_to_call
1771-
display(MLIR.IR.mmodule())
1775+
enzymexla.triton_call(
1776+
grid_x.mlir_data,
1777+
grid_y.mlir_data,
1778+
grid_z.mlir_data,
1779+
shmem.mlir_data,
1780+
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
1781+
fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
1782+
result_0=MLIR.IR.Type[],
1783+
location,
1784+
)
17721785

1773-
error("TODO: implement triton_call")
1786+
return nothing
17741787
end
17751788

17761789
"""

0 commit comments

Comments
 (0)