33# Julia and Reactant semantics should be considered on the higher abstractions that use these
44module Ops
55using .. MLIR: MLIR
6- using .. MLIR. Dialects: stablehlo, chlo, enzyme, enzymexla
6+ using .. MLIR. Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext
77using .. Reactant:
88 Reactant,
99 TracedRArray,
@@ -1704,32 +1704,52 @@ function _extract_function(
17041704 code:: String ;
17051705 func_name:: String = " main" ,
17061706 func_op_kind:: String = " func.func" ,
1707- nested_module:: Bool = false ,
17081707 location:: MLIR.IR.Location = MLIR. IR. Location (),
17091708)
17101709 module_suffix = string (hash (code); base= 16 )
17111710 name_to_call = func_name * " _call_" * module_suffix
17121711 mod_name = func_name * " _module_" * module_suffix
17131712 symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
17141713
1715- if nested_module
1714+ use_ttext_module = split (func_op_kind, " ." )[1 ] == " tt"
1715+
1716+ if use_ttext_module
1717+ tt_mod_name = func_name * " _tt_module_" * module_suffix
1718+ tt_region = MLIR. IR. Region ()
1719+ tt_block = MLIR. IR. Block ()
1720+ push! (tt_region, tt_block)
1721+ triton_mod_op = triton_ext. module_ (;
1722+ location, bodyRegion= tt_region, sym_name= tt_mod_name
1723+ )
1724+ MLIR. IR. rmfromparent! (triton_mod_op)
1725+ push! (MLIR. IR. body (MLIR. IR. mmodule ()), triton_mod_op) # insert into parent module
1726+
17161727 region = MLIR. IR. Region ()
17171728 push! (region, MLIR. IR. Block ())
17181729 moduleop = MLIR. Dialects. builtin. module_ (;
17191730 location, bodyRegion= region, sym_name= mod_name
17201731 )
17211732 MLIR. IR. rmfromparent! (moduleop)
1722- push! (MLIR . IR . body (MLIR . IR . mmodule ()) , moduleop) # insert into parent module
1733+ push! (tt_block , moduleop) # insert into triton module
17231734
17241735 top_level_block = MLIR. IR. Block (
17251736 MLIR. API. mlirModuleGetBody (MLIR. API. mlirModuleFromOperation (moduleop)), false
17261737 )
17271738 fn = nothing
1739+
1740+ symref = MLIR. IR. SymbolRefAttribute (
1741+ tt_mod_name,
1742+ MLIR. IR. Attribute[
1743+ MLIR. IR. FlatSymbolRefAttribute (mod_name),
1744+ MLIR. IR. FlatSymbolRefAttribute (name_to_call),
1745+ ],
1746+ )
17281747 else
17291748 current_module = MLIR. IR. mmodule ()
17301749 moduleop = MLIR. IR. Operation (current_module)
17311750 top_level_block = MLIR. IR. body (current_module)
17321751 fn = MLIR. IR. lookup (MLIR. IR. SymbolTable (moduleop), name_to_call)
1752+ symref = MLIR. IR. FlatSymbolRefAttribute (name_to_call)
17331753 end
17341754
17351755 if isnothing (fn)
@@ -1750,12 +1770,14 @@ function _extract_function(
17501770 )
17511771 @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
17521772
1753- # Set function private
1754- MLIR. IR. attr! (
1755- op,
1756- MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1757- MLIR. IR. Attribute (" private" ),
1758- )
1773+ if ! use_ttext_module
1774+ # Set function private
1775+ MLIR. IR. attr! (
1776+ op,
1777+ MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1778+ MLIR. IR. Attribute (" private" ),
1779+ )
1780+ end
17591781
17601782 # Change function name
17611783 MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call))
@@ -1770,7 +1792,7 @@ function _extract_function(
17701792 error (" hlo_call: could not find function $func_name in the provided module" )
17711793 end
17721794
1773- return fn, name_to_call, mod_name
1795+ return fn, symref
17741796end
17751797
17761798function triton_call (
@@ -1784,19 +1806,15 @@ function triton_call(
17841806 location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
17851807 # TODO : other kwargs
17861808)
1787- _, name_to_call, mod_name = _extract_function (
1788- mlir_code; func_name, func_op_kind= " tt.func" , nested_module= true , location
1789- )
1809+ _, symref = _extract_function (mlir_code; func_name, func_op_kind= " tt.func" , location)
17901810
1791- enzymexla . triton_call (
1811+ triton_ext . call (
17921812 grid_x. mlir_data,
17931813 grid_y. mlir_data,
17941814 grid_z. mlir_data,
17951815 shmem. mlir_data,
17961816 [Reactant. TracedUtils. get_mlir_data (a) for a in args];
1797- fn= MLIR. IR. SymbolRefAttribute (
1798- mod_name, MLIR. IR. Attribute[MLIR. IR. FlatSymbolRefAttribute (name_to_call)]
1799- ),
1817+ fn= symref,
18001818 result_0= MLIR. IR. Type[],
18011819 location,
18021820 )
@@ -1834,9 +1852,7 @@ julia> Reactant.@jit(
18341852 func_name= " main" ,
18351853 location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
18361854)
1837- fn, name_to_call, _ = _extract_function (
1838- code; func_name, func_op_kind= " func.func" , location
1839- )
1855+ fn, symref = _extract_function (code; func_name, func_op_kind= " func.func" , location)
18401856
18411857 ftype_attr = MLIR. IR. attr (fn, " function_type" )
18421858 ftype = MLIR. IR. Type (ftype_attr)
@@ -1853,7 +1869,7 @@ julia> Reactant.@jit(
18531869 call = MLIR. Dialects. func. call (
18541870 operands;
18551871 result_0= [MLIR. IR. result (ftype, i) for i in 1 : MLIR. IR. nresults (ftype)],
1856- callee= MLIR . IR . FlatSymbolRefAttribute (name_to_call) ,
1872+ callee= symref ,
18571873 location,
18581874 )
18591875
0 commit comments