@@ -1851,11 +1851,28 @@ function triton_call(
18511851 block_y:: TracedRNumber{<:Integer} ,
18521852 block_z:: TracedRNumber{<:Integer} ,
18531853 location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
1854- # TODO : other kwargs
18551854)
18561855 _, symref = _extract_function (mlir_code; func_name, func_op_kind= " tt.func" , location)
18571856
1858- triton_ext. call (
1857+ result_types = MLIR. IR. Type[]
1858+ output_operand_aliases = MLIR. IR. Attribute[]
1859+ output_to_arg = Int[]
1860+ for (i, arg) in enumerate (args)
1861+ if arg isa TracedRArray
1862+ push! (result_types, mlir_type (typeof (arg), size (arg)))
1863+ push! (
1864+ output_operand_aliases,
1865+ MLIR. IR. Attribute (
1866+ MLIR. API. stablehloOutputOperandAliasGet (
1867+ MLIR. IR. context (), 0 , C_NULL , Int64 (i - 1 ), 0 , C_NULL
1868+ ),
1869+ ),
1870+ )
1871+ push! (output_to_arg, i)
1872+ end
1873+ end
1874+
1875+ results = triton_ext. call (
18591876 grid_x. mlir_data,
18601877 grid_y. mlir_data,
18611878 grid_z. mlir_data,
@@ -1864,11 +1881,23 @@ function triton_call(
18641881 block_z. mlir_data,
18651882 [Reactant. TracedUtils. get_mlir_data (a) for a in args];
18661883 fn= symref,
1867- result_0= MLIR . IR . Type[] ,
1884+ result_0= result_types ,
18681885 location,
1886+ output_operand_aliases,
18691887 )
18701888
1871- return nothing
1889+ array_results = ()
1890+ for i in 1 : MLIR. IR. nresults (results)
1891+ arg = args[output_to_arg[i]]
1892+ array_results = (
1893+ array_results... ,
1894+ Reactant. TracedRArray {unwrapped_eltype(arg),ndims(arg)} (
1895+ (), MLIR. IR. result (results, i), size (arg)
1896+ ),
1897+ )
1898+ end
1899+ length (array_results) == 1 && return array_results[1 ]
1900+ return array_results
18721901end
18731902
18741903"""
0 commit comments