@@ -1698,68 +1698,38 @@ end
16981698end
16991699
17001700# Generate a unique name given a module hash and a function name.
1701- function _hlo_call_name (orig_name, module_suffix)
1702- return orig_name * " _hlo_call_" * module_suffix
1703- end
1701+ _new_function_name (orig_name, module_suffix) = orig_name * " _call_" * module_suffix
17041702
1705- """
1706- hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1707-
1708- Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1709- with the provided arguments and return a tuple for each result of the call.
1710-
1711- ```julia-repl
1712- julia> Reactant.@jit(
1713- hlo_call(
1714- \"\"\"
1715- module {
1716- func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1717- %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1718- return %0 : tensor<3xf32>
1719- }
1720- }
1721- \"\"\" ,
1722- Reactant.to_rarray(Float32[1, 2, 3]),
1723- Reactant.to_rarray(Float32[1, 2, 3]),
1724- )
1725- )
1726- (ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1727- ```
1728- """
1729- @noinline function hlo_call (
1730- code,
1731- args... ;
1732- func_name= " main" ,
1733- location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1703+ function _extract_function (
1704+ code:: String ; func_name:: String = " main" , func_op_kind:: String = " func.func"
17341705)
17351706 module_suffix = string (hash (code); base= 16 )
1736- name_to_call = _hlo_call_name (func_name, module_suffix)
1707+ name_to_call = _new_function_name (func_name, module_suffix)
17371708
17381709 current_module = MLIR. IR. mmodule ()
17391710 top_level_block = MLIR. IR. body (current_module)
17401711
17411712 symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
1742-
17431713 fn = MLIR. IR. lookup (
17441714 MLIR. IR. SymbolTable (MLIR. IR. Operation (current_module)), name_to_call
17451715 )
1716+
17461717 if isnothing (fn)
17471718 new_mod = parse (MLIR. IR. Module, code)
17481719 new_mod_op = MLIR. IR. Operation (new_mod)
17491720 body = MLIR. IR. body (new_mod)
17501721
17511722 operations = collect (MLIR. IR. OperationIterator (body))
17521723 for op in operations
1753- if MLIR. IR. name (op) == " func.func "
1724+ if MLIR. IR. name (op) == func_op_kind
17541725 fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
17551726 if fn_name == func_name
17561727 fn = op
17571728 end
17581729
1759- new_name = _hlo_call_name (fn_name, module_suffix)
17601730 res = MLIR. IR. LogicalResult (
17611731 MLIR. API. mlirSymbolTableReplaceAllSymbolUses (
1762- fn_name, new_name , new_mod_op
1732+ fn_name, name_to_call , new_mod_op
17631733 ),
17641734 )
17651735 @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
@@ -1772,7 +1742,7 @@ julia> Reactant.@jit(
17721742 )
17731743
17741744 # Change function name
1775- MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (new_name ))
1745+ MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call ))
17761746 end
17771747 end
17781748
@@ -1786,11 +1756,59 @@ julia> Reactant.@jit(
17861756 error (" hlo_call: could not find function $func_name in the provided module" )
17871757 end
17881758
1759+ return name_to_call
1760+ end
1761+
1762+ function triton_call (
1763+ mlir_code:: String ,
1764+ args:: Union{TracedRArray,TracedRNumber,Number} ...;
1765+ func_name:: String = " main" ,
1766+ location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
1767+ )
1768+ name_to_call = _extract_function (mlir_code; func_name, func_op_kind= " tt.func" )
1769+
1770+ @show name_to_call
1771+ display (MLIR. IR. mmodule ())
1772+
1773+ error (" TODO: implement triton_call" )
1774+ end
1775+
1776+ """
1777+ hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1778+
1779+ Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1780+ with the provided arguments and return a tuple for each result of the call.
1781+
1782+ ```julia-repl
1783+ julia> Reactant.@jit(
1784+ hlo_call(
1785+ \"\"\"
1786+ module {
1787+ func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1788+ %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1789+ return %0 : tensor<3xf32>
1790+ }
1791+ }
1792+ \"\"\" ,
1793+ Reactant.to_rarray(Float32[1, 2, 3]),
1794+ Reactant.to_rarray(Float32[1, 2, 3]),
1795+ )
1796+ )
1797+ (ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1798+ ```
1799+ """
1800+ @noinline function hlo_call (
1801+ code,
1802+ args:: Union{TracedRArray,TracedRNumber} ...;
1803+ func_name= " main" ,
1804+ location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1805+ )
1806+ name_to_call = _extract_function (code; func_name, func_op_kind= " func.func" )
1807+
17891808 ftype_attr = MLIR. IR. attr (fn, " function_type" )
17901809 ftype = MLIR. IR. Type (ftype_attr)
17911810
1792- @assert all (Base. Fix2 (isa, Union{TracedRArray,TracedRNumber}), args) " hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
1793- @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name "
1811+ @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name . Expected $(MLIR. IR. ninputs (ftype)) , got $(length (args)) "
17941812
17951813 for (i, arg) in enumerate (args)
17961814 expected_type = MLIR. IR. input (ftype, i)
0 commit comments