@@ -1743,68 +1743,38 @@ end
17431743end
17441744
17451745# Generate a unique name given a module hash and a function name.
1746- function _hlo_call_name (orig_name, module_suffix)
1747- return orig_name * " _hlo_call_" * module_suffix
1748- end
1746+ _new_function_name (orig_name, module_suffix) = orig_name * " _call_" * module_suffix
17491747
1750- """
1751- hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1752-
1753- Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1754- with the provided arguments and return a tuple for each result of the call.
1755-
1756- ```julia-repl
1757- julia> Reactant.@jit(
1758- hlo_call(
1759- \"\"\"
1760- module {
1761- func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1762- %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1763- return %0 : tensor<3xf32>
1764- }
1765- }
1766- \"\"\" ,
1767- Reactant.to_rarray(Float32[1, 2, 3]),
1768- Reactant.to_rarray(Float32[1, 2, 3]),
1769- )
1770- )
1771- (ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1772- ```
1773- """
1774- @noinline function hlo_call (
1775- code,
1776- args... ;
1777- func_name= " main" ,
1778- location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1748+ function _extract_function (
1749+ code:: String ; func_name:: String = " main" , func_op_kind:: String = " func.func"
17791750)
17801751 module_suffix = string (hash (code); base= 16 )
1781- name_to_call = _hlo_call_name (func_name, module_suffix)
1752+ name_to_call = _new_function_name (func_name, module_suffix)
17821753
17831754 current_module = MLIR. IR. mmodule ()
17841755 top_level_block = MLIR. IR. body (current_module)
17851756
17861757 symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
1787-
17881758 fn = MLIR. IR. lookup (
17891759 MLIR. IR. SymbolTable (MLIR. IR. Operation (current_module)), name_to_call
17901760 )
1761+
17911762 if isnothing (fn)
17921763 new_mod = parse (MLIR. IR. Module, code)
17931764 new_mod_op = MLIR. IR. Operation (new_mod)
17941765 body = MLIR. IR. body (new_mod)
17951766
17961767 operations = collect (MLIR. IR. OperationIterator (body))
17971768 for op in operations
1798- if MLIR. IR. name (op) == " func.func "
1769+ if MLIR. IR. name (op) == func_op_kind
17991770 fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
18001771 if fn_name == func_name
18011772 fn = op
18021773 end
18031774
1804- new_name = _hlo_call_name (fn_name, module_suffix)
18051775 res = MLIR. IR. LogicalResult (
18061776 MLIR. API. mlirSymbolTableReplaceAllSymbolUses (
1807- fn_name, new_name , new_mod_op
1777+ fn_name, name_to_call , new_mod_op
18081778 ),
18091779 )
18101780 @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
@@ -1817,7 +1787,7 @@ julia> Reactant.@jit(
18171787 )
18181788
18191789 # Change function name
1820- MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (new_name ))
1790+ MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call ))
18211791 end
18221792 end
18231793
@@ -1831,11 +1801,59 @@ julia> Reactant.@jit(
18311801 error (" hlo_call: could not find function $func_name in the provided module" )
18321802 end
18331803
1804+ return name_to_call
1805+ end
1806+
1807+ function triton_call (
1808+ mlir_code:: String ,
1809+ args:: Union{TracedRArray,TracedRNumber,Number} ...;
1810+ func_name:: String = " main" ,
1811+ location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
1812+ )
1813+ name_to_call = _extract_function (mlir_code; func_name, func_op_kind= " tt.func" )
1814+
1815+ @show name_to_call
1816+ display (MLIR. IR. mmodule ())
1817+
1818+ error (" TODO: implement triton_call" )
1819+ end
1820+
1821+ """
1822+ hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1823+
1824+ Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1825+ with the provided arguments and return a tuple for each result of the call.
1826+
1827+ ```julia-repl
1828+ julia> Reactant.@jit(
1829+ hlo_call(
1830+ \"\"\"
1831+ module {
1832+ func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1833+ %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1834+ return %0 : tensor<3xf32>
1835+ }
1836+ }
1837+ \"\"\" ,
1838+ Reactant.to_rarray(Float32[1, 2, 3]),
1839+ Reactant.to_rarray(Float32[1, 2, 3]),
1840+ )
1841+ )
1842+ (ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1843+ ```
1844+ """
1845+ @noinline function hlo_call (
1846+ code,
1847+ args:: Union{TracedRArray,TracedRNumber} ...;
1848+ func_name= " main" ,
1849+ location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1850+ )
1851+ name_to_call = _extract_function (code; func_name, func_op_kind= " func.func" )
1852+
18341853 ftype_attr = MLIR. IR. attr (fn, " function_type" )
18351854 ftype = MLIR. IR. Type (ftype_attr)
18361855
1837- @assert all (Base. Fix2 (isa, Union{TracedRArray,TracedRNumber}), args) " hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
1838- @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name "
1856+ @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name . Expected $(MLIR. IR. ninputs (ftype)) , got $(length (args)) "
18391857
18401858 for (i, arg) in enumerate (args)
18411859 expected_type = MLIR. IR. input (ftype, i)
0 commit comments