Skip to content

Commit 7e213c3

Browse files
committed
feat: return values
1 parent 7ad8b4c commit 7e213c3

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function overlayed_pycall_with_triton(
120120
# we are compiling here + lowering again inside enzymejax
121121
ccinfo = triton.compile(src; target=target, options=options.__dict__)
122122

123-
@opcall triton_call(
123+
return @opcall triton_call(
124124
pyconvert(String, ccinfo.asm["source"]),
125125
filter(x -> x isa Reactant.TracedType, args)...;
126126
func_name=pyconvert(String, ccinfo.metadata.name),
@@ -131,6 +131,4 @@ function overlayed_pycall_with_triton(
131131
block_y=@opcall(constant(blocks[2])),
132132
block_z=@opcall(constant(blocks[3])),
133133
)
134-
135-
return nothing
136134
end

src/Ops.jl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,11 +1851,28 @@ function triton_call(
18511851
block_y::TracedRNumber{<:Integer},
18521852
block_z::TracedRNumber{<:Integer},
18531853
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
1854-
# TODO: other kwargs
18551854
)
18561855
_, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
18571856

1858-
triton_ext.call(
1857+
result_types = MLIR.IR.Type[]
1858+
output_operand_aliases = MLIR.IR.Attribute[]
1859+
output_to_arg = Int[]
1860+
for (i, arg) in enumerate(args)
1861+
if arg isa TracedRArray
1862+
push!(result_types, mlir_type(typeof(arg), size(arg)))
1863+
push!(
1864+
output_operand_aliases,
1865+
MLIR.IR.Attribute(
1866+
MLIR.API.stablehloOutputOperandAliasGet(
1867+
MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
1868+
),
1869+
),
1870+
)
1871+
push!(output_to_arg, i)
1872+
end
1873+
end
1874+
1875+
results = triton_ext.call(
18591876
grid_x.mlir_data,
18601877
grid_y.mlir_data,
18611878
grid_z.mlir_data,
@@ -1864,11 +1881,23 @@ function triton_call(
18641881
block_z.mlir_data,
18651882
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18661883
fn=symref,
1867-
result_0=MLIR.IR.Type[],
1884+
result_0=result_types,
18681885
location,
1886+
output_operand_aliases,
18691887
)
18701888

1871-
return nothing
1889+
array_results = ()
1890+
for i in 1:MLIR.IR.nresults(results)
1891+
arg = args[output_to_arg[i]]
1892+
array_results = (
1893+
array_results...,
1894+
Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}(
1895+
(), MLIR.IR.result(results, i), size(arg)
1896+
),
1897+
)
1898+
end
1899+
length(array_results) == 1 && return array_results[1]
1900+
return array_results
18721901
end
18731902

18741903
"""

0 commit comments

Comments
 (0)