@@ -1705,54 +1705,60 @@ function _extract_function(
17051705 func_name:: String = " main" ,
17061706 func_op_kind:: String = " func.func" ,
17071707 nested_module:: Bool = false ,
1708+ location:: MLIR.IR.Location = MLIR. IR. Location (),
17081709)
17091710 module_suffix = string (hash (code); base= 16 )
1710- name_to_call = _new_function_name (func_name, module_suffix)
1711+ name_to_call = func_name * " _call_" * module_suffix
1712+ mod_name = func_name * " _module_" * module_suffix
1713+ symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
17111714
1712- current_module = MLIR. IR. mmodule ()
17131715 if nested_module
1714- new_module = MLIR. IR. Module ()
1715- push! (MLIR. IR. body (current_module), MLIR. IR. Operation (new_module, true ))
1716- current_module = new_module
1717- end
1718- top_level_block = MLIR. IR. body (current_module)
1716+ region = MLIR. IR. Region ()
1717+ push! (region, MLIR. IR. Block ())
1718+ moduleop = MLIR. Dialects. builtin. module_ (;
1719+ location, bodyRegion= region, sym_name= mod_name
1720+ )
1721+ MLIR. IR. rmfromparent! (moduleop)
1722+ push! (MLIR. IR. body (MLIR. IR. mmodule ()), moduleop) # insert into parent module
17191723
1720- symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
1721- fn = MLIR. IR. lookup (
1722- MLIR. IR. SymbolTable (MLIR. IR. Operation (current_module)), name_to_call
1723- )
1724+ top_level_block = MLIR. IR. Block (
1725+ MLIR. API. mlirModuleGetBody (MLIR. API. mlirModuleFromOperation (moduleop)), false
1726+ )
1727+ fn = nothing
1728+ else
1729+ current_module = MLIR. IR. mmodule ()
1730+ moduleop = MLIR. IR. Operation (current_module)
1731+ top_level_block = MLIR. IR. body (current_module)
1732+ fn = MLIR. IR. lookup (MLIR. IR. SymbolTable (moduleop), name_to_call)
1733+ end
17241734
17251735 if isnothing (fn)
17261736 new_mod = parse (MLIR. IR. Module, code)
17271737 new_mod_op = MLIR. IR. Operation (new_mod)
17281738 body = MLIR. IR. body (new_mod)
17291739
17301740 operations = collect (MLIR. IR. OperationIterator (body))
1731- for op in operations
1732- if MLIR. IR. name (op) == func_op_kind
1733- fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
1734- if fn_name == func_name
1735- fn = op
1736- end
1741+ idx = Base. findfirst (op -> MLIR. IR. name (op) == func_op_kind, operations)
1742+ @assert idx != = nothing
1743+ op = operations[idx]
17371744
1738- res = MLIR. IR. LogicalResult (
1739- MLIR. API. mlirSymbolTableReplaceAllSymbolUses (
1740- fn_name, name_to_call, new_mod_op
1741- ),
1742- )
1743- @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
1744-
1745- # Set function private
1746- MLIR. IR. attr! (
1747- op,
1748- MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1749- MLIR. IR. Attribute (" private" ),
1750- )
1751-
1752- # Change function name
1753- MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call))
1754- end
1755- end
1745+ fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
1746+ fn_name == func_name && (fn = op)
1747+
1748+ res = MLIR. IR. LogicalResult (
1749+ MLIR. API. mlirSymbolTableReplaceAllSymbolUses (fn_name, name_to_call, new_mod_op)
1750+ )
1751+ @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
1752+
1753+ # Set function private
1754+ MLIR. IR. attr! (
1755+ op,
1756+ MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1757+ MLIR. IR. Attribute (" private" ),
1758+ )
1759+
1760+ # Change function name
1761+ MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call))
17561762
17571763 for op in operations
17581764 MLIR. IR. rmfromparent! (op)
@@ -1764,7 +1770,7 @@ function _extract_function(
17641770 error (" hlo_call: could not find function $func_name in the provided module" )
17651771 end
17661772
1767- return fn, name_to_call
1773+ return fn, name_to_call, mod_name
17681774end
17691775
17701776function triton_call (
@@ -1778,8 +1784,8 @@ function triton_call(
17781784 location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
17791785 # TODO : other kwargs
17801786)
1781- _, name_to_call = _extract_function (
1782- mlir_code; func_name, func_op_kind= " tt.func" , nested_module= true
1787+ _, name_to_call, mod_name = _extract_function (
1788+ mlir_code; func_name, func_op_kind= " tt.func" , nested_module= true , location
17831789 )
17841790
17851791 enzymexla. triton_call (
@@ -1788,7 +1794,9 @@ function triton_call(
17881794 grid_z. mlir_data,
17891795 shmem. mlir_data,
17901796 [Reactant. TracedUtils. get_mlir_data (a) for a in args];
1791- fn= MLIR. IR. FlatSymbolRefAttribute (name_to_call),
1797+ fn= MLIR. IR. SymbolRefAttribute (
1798+ mod_name, MLIR. IR. Attribute[MLIR. IR. FlatSymbolRefAttribute (name_to_call)]
1799+ ),
17921800 result_0= MLIR. IR. Type[],
17931801 location,
17941802 )
@@ -1826,7 +1834,9 @@ julia> Reactant.@jit(
18261834 func_name= " main" ,
18271835 location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
18281836)
1829- fn, name_to_call = _extract_function (code; func_name, func_op_kind= " func.func" )
1837+ fn, name_to_call, _ = _extract_function (
1838+ code; func_name, func_op_kind= " func.func" , location
1839+ )
18301840
18311841 ftype_attr = MLIR. IR. attr (fn, " function_type" )
18321842 ftype = MLIR. IR. Type (ftype_attr)
0 commit comments