Skip to content

Commit 4a9a1ce

Browse files
committed
feat: return values
1 parent bbe23a0 commit 4a9a1ce

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function overlayed_pycall_with_triton(
120120
# we are compiling here + lowering again inside enzymejax
121121
ccinfo = triton.compile(src; target=target, options=options.__dict__)
122122

123-
@opcall triton_call(
123+
return @opcall triton_call(
124124
pyconvert(String, ccinfo.asm["source"]),
125125
filter(x -> x isa Reactant.TracedType, args)...;
126126
func_name=pyconvert(String, ccinfo.metadata.name),
@@ -131,6 +131,4 @@ function overlayed_pycall_with_triton(
131131
block_y=@opcall(constant(blocks[2])),
132132
block_z=@opcall(constant(blocks[3])),
133133
)
134-
135-
return nothing
136134
end

src/Ops.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,11 +1806,28 @@ function triton_call(
18061806
block_y::TracedRNumber{<:Integer},
18071807
block_z::TracedRNumber{<:Integer},
18081808
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
1809-
# TODO: other kwargs
18101809
)
18111810
_, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
18121811

1813-
triton_ext.call(
1812+
result_types = MLIR.IR.Type[]
1813+
output_operand_aliases = MLIR.IR.Attribute[]
1814+
output_to_arg = Int[]
1815+
for (i, arg) in enumerate(args)
1816+
if arg isa TracedRArray
1817+
push!(result_types, mlir_type(typeof(arg), size(arg)))
1818+
push!(
1819+
output_operand_aliases,
1820+
MLIR.IR.Attribute(
1821+
MLIR.API.stablehloOutputOperandAliasGet(
1822+
MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
1823+
),
1824+
),
1825+
)
1826+
push!(output_to_arg, i)
1827+
end
1828+
end
1829+
1830+
results = triton_ext.call(
18141831
grid_x.mlir_data,
18151832
grid_y.mlir_data,
18161833
grid_z.mlir_data,
@@ -1819,11 +1836,23 @@ function triton_call(
18191836
block_z.mlir_data,
18201837
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18211838
fn=symref,
1822-
result_0=MLIR.IR.Type[],
1839+
result_0=result_types,
18231840
location,
1841+
output_operand_aliases,
18241842
)
18251843

1826-
return nothing
1844+
array_results = ()
1845+
for i in 1:MLIR.IR.nresults(results)
1846+
arg = args[output_to_arg[i]]
1847+
array_results = (
1848+
array_results...,
1849+
Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}(
1850+
(), MLIR.IR.result(results, i), size(arg)
1851+
),
1852+
)
1853+
end
1854+
length(array_results) == 1 && return array_results[1]
1855+
return array_results
18271856
end
18281857

18291858
"""

0 commit comments

Comments
 (0)