@@ -1750,54 +1750,60 @@ function _extract_function(
17501750 func_name:: String = " main" ,
17511751 func_op_kind:: String = " func.func" ,
17521752 nested_module:: Bool = false ,
1753+ location:: MLIR.IR.Location = MLIR. IR. Location (),
17531754)
17541755 module_suffix = string (hash (code); base= 16 )
1755- name_to_call = _new_function_name (func_name, module_suffix)
1756+ name_to_call = func_name * " _call_" * module_suffix
1757+ mod_name = func_name * " _module_" * module_suffix
1758+ symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
17561759
1757- current_module = MLIR. IR. mmodule ()
17581760 if nested_module
1759- new_module = MLIR. IR. Module ()
1760- push! (MLIR. IR. body (current_module), MLIR. IR. Operation (new_module, true ))
1761- current_module = new_module
1762- end
1763- top_level_block = MLIR. IR. body (current_module)
1761+ region = MLIR. IR. Region ()
1762+ push! (region, MLIR. IR. Block ())
1763+ moduleop = MLIR. Dialects. builtin. module_ (;
1764+ location, bodyRegion= region, sym_name= mod_name
1765+ )
1766+ MLIR. IR. rmfromparent! (moduleop)
1767+ push! (MLIR. IR. body (MLIR. IR. mmodule ()), moduleop) # insert into parent module
17641768
1765- symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
1766- fn = MLIR. IR. lookup (
1767- MLIR. IR. SymbolTable (MLIR. IR. Operation (current_module)), name_to_call
1768- )
1769+ top_level_block = MLIR. IR. Block (
1770+ MLIR. API. mlirModuleGetBody (MLIR. API. mlirModuleFromOperation (moduleop)), false
1771+ )
1772+ fn = nothing
1773+ else
1774+ current_module = MLIR. IR. mmodule ()
1775+ moduleop = MLIR. IR. Operation (current_module)
1776+ top_level_block = MLIR. IR. body (current_module)
1777+ fn = MLIR. IR. lookup (MLIR. IR. SymbolTable (moduleop), name_to_call)
1778+ end
17691779
17701780 if isnothing (fn)
17711781 new_mod = parse (MLIR. IR. Module, code)
17721782 new_mod_op = MLIR. IR. Operation (new_mod)
17731783 body = MLIR. IR. body (new_mod)
17741784
17751785 operations = collect (MLIR. IR. OperationIterator (body))
1776- for op in operations
1777- if MLIR. IR. name (op) == func_op_kind
1778- fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
1779- if fn_name == func_name
1780- fn = op
1781- end
1786+ idx = Base. findfirst (op -> MLIR. IR. name (op) == func_op_kind, operations)
1787+ @assert idx != = nothing
1788+ op = operations[idx]
17821789
1783- res = MLIR. IR. LogicalResult (
1784- MLIR. API. mlirSymbolTableReplaceAllSymbolUses (
1785- fn_name, name_to_call, new_mod_op
1786- ),
1787- )
1788- @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
1789-
1790- # Set function private
1791- MLIR. IR. attr! (
1792- op,
1793- MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1794- MLIR. IR. Attribute (" private" ),
1795- )
1796-
1797- # Change function name
1798- MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call))
1799- end
1800- end
1790+ fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
1791+ fn_name == func_name && (fn = op)
1792+
1793+ res = MLIR. IR. LogicalResult (
1794+ MLIR. API. mlirSymbolTableReplaceAllSymbolUses (fn_name, name_to_call, new_mod_op)
1795+ )
1796+ @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
1797+
1798+ # Set function private
1799+ MLIR. IR. attr! (
1800+ op,
1801+ MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1802+ MLIR. IR. Attribute (" private" ),
1803+ )
1804+
1805+ # Change function name
1806+ MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call))
18011807
18021808 for op in operations
18031809 MLIR. IR. rmfromparent! (op)
@@ -1809,7 +1815,7 @@ function _extract_function(
18091815 error (" hlo_call: could not find function $func_name in the provided module" )
18101816 end
18111817
1812- return fn, name_to_call
1818+ return fn, name_to_call, mod_name
18131819end
18141820
18151821function triton_call (
@@ -1823,8 +1829,8 @@ function triton_call(
18231829 location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
18241830 # TODO : other kwargs
18251831)
1826- _, name_to_call = _extract_function (
1827- mlir_code; func_name, func_op_kind= " tt.func" , nested_module= true
1832+ _, name_to_call, mod_name = _extract_function (
1833+ mlir_code; func_name, func_op_kind= " tt.func" , nested_module= true , location
18281834 )
18291835
18301836 enzymexla. triton_call (
@@ -1833,7 +1839,9 @@ function triton_call(
18331839 grid_z. mlir_data,
18341840 shmem. mlir_data,
18351841 [Reactant. TracedUtils. get_mlir_data (a) for a in args];
1836- fn= MLIR. IR. FlatSymbolRefAttribute (name_to_call),
1842+ fn= MLIR. IR. SymbolRefAttribute (
1843+ mod_name, MLIR. IR. Attribute[MLIR. IR. FlatSymbolRefAttribute (name_to_call)]
1844+ ),
18371845 result_0= MLIR. IR. Type[],
18381846 location,
18391847 )
@@ -1871,7 +1879,9 @@ julia> Reactant.@jit(
18711879 func_name= " main" ,
18721880 location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
18731881)
1874- fn, name_to_call = _extract_function (code; func_name, func_op_kind= " func.func" )
1882+ fn, name_to_call, _ = _extract_function (
1883+ code; func_name, func_op_kind= " func.func" , location
1884+ )
18751885
18761886 ftype_attr = MLIR. IR. attr (fn, " function_type" )
18771887 ftype = MLIR. IR. Type (ftype_attr)
0 commit comments