@@ -1806,11 +1806,28 @@ function triton_call(
18061806 block_y:: TracedRNumber{<:Integer} ,
18071807 block_z:: TracedRNumber{<:Integer} ,
18081808 location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
1809- # TODO : other kwargs
18101809)
18111810 _, symref = _extract_function (mlir_code; func_name, func_op_kind= " tt.func" , location)
18121811
1813- triton_ext. call (
1812+ result_types = MLIR. IR. Type[]
1813+ output_operand_aliases = MLIR. IR. Attribute[]
1814+ output_to_arg = Int[]
1815+ for (i, arg) in enumerate (args)
1816+ if arg isa TracedRArray
1817+ push! (result_types, mlir_type (typeof (arg), size (arg)))
1818+ push! (
1819+ output_operand_aliases,
1820+ MLIR. IR. Attribute (
1821+ MLIR. API. stablehloOutputOperandAliasGet (
1822+ MLIR. IR. context (), 0 , C_NULL , Int64 (i - 1 ), 0 , C_NULL
1823+ ),
1824+ ),
1825+ )
1826+ push! (output_to_arg, i)
1827+ end
1828+ end
1829+
1830+ results = triton_ext. call (
18141831 grid_x. mlir_data,
18151832 grid_y. mlir_data,
18161833 grid_z. mlir_data,
@@ -1819,11 +1836,23 @@ function triton_call(
18191836 block_z. mlir_data,
18201837 [Reactant. TracedUtils. get_mlir_data (a) for a in args];
18211838 fn= symref,
1822- result_0= MLIR . IR . Type[] ,
1839+ result_0= result_types ,
18231840 location,
1841+ output_operand_aliases,
18241842 )
18251843
1826- return nothing
1844+ array_results = ()
1845+ for i in 1 : MLIR. IR. nresults (results)
1846+ arg = args[output_to_arg[i]]
1847+ array_results = (
1848+ array_results... ,
1849+ Reactant. TracedRArray {unwrapped_eltype(arg),ndims(arg)} (
1850+ (), MLIR. IR. result (results, i), size (arg)
1851+ ),
1852+ )
1853+ end
1854+ length (array_results) == 1 && return array_results[1 ]
1855+ return array_results
18271856end
18281857
18291858"""
0 commit comments