Skip to content

Commit ba29516

Browse files
committed
feat: copy tt.func into main module [skip ci]
1 parent 6c1d287 commit ba29516

File tree

2 files changed

+70
-42
lines changed

2 files changed

+70
-42
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,18 @@ function overlayed_pycall_with_triton(
108108

109109
ccinfo = triton.compile(src; target=target, options=options.__dict__)
110110

111-
println(pyconvert(String, ccinfo.asm["source"]))
112-
println(pyconvert(String, ccinfo.asm["ttir"]))
111+
@show ccinfo.metadata
112+
@show ccinfo.asm.keys()
113+
# shared = ccinfo.metadata["shared"]
114+
kernel_name = pyconvert(String, ccinfo.metadata.name)
115+
# cluster_dims = ccinfo.metadata["cluster_dims"]
116+
117+
# println(pyconvert(String, ccinfo.asm["source"]))
118+
# println(pyconvert(String, ccinfo.asm["ttir"]))
119+
120+
res = @opcall triton_call(
121+
pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name
122+
)
113123

114124
return error("TODO: implement triton")
115125
end

src/Ops.jl

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,68 +1743,38 @@ end
17431743
end
17441744

17451745
# Generate a unique name given a module hash and a function name.
1746-
function _hlo_call_name(orig_name, module_suffix)
1747-
return orig_name * "_hlo_call_" * module_suffix
1748-
end
1746+
_new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix
17491747

1750-
"""
1751-
hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1752-
1753-
Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1754-
with the provided arguments and return a tuple for each result of the call.
1755-
1756-
```julia-repl
1757-
julia> Reactant.@jit(
1758-
hlo_call(
1759-
\"\"\"
1760-
module {
1761-
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1762-
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1763-
return %0 : tensor<3xf32>
1764-
}
1765-
}
1766-
\"\"\",
1767-
Reactant.to_rarray(Float32[1, 2, 3]),
1768-
Reactant.to_rarray(Float32[1, 2, 3]),
1769-
)
1770-
)
1771-
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1772-
```
1773-
"""
1774-
@noinline function hlo_call(
1775-
code,
1776-
args...;
1777-
func_name="main",
1778-
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
1748+
function _extract_function(
1749+
code::String; func_name::String="main", func_op_kind::String="func.func"
17791750
)
17801751
module_suffix = string(hash(code); base=16)
1781-
name_to_call = _hlo_call_name(func_name, module_suffix)
1752+
name_to_call = _new_function_name(func_name, module_suffix)
17821753

17831754
current_module = MLIR.IR.mmodule()
17841755
top_level_block = MLIR.IR.body(current_module)
17851756

17861757
symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
1787-
17881758
fn = MLIR.IR.lookup(
17891759
MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
17901760
)
1761+
17911762
if isnothing(fn)
17921763
new_mod = parse(MLIR.IR.Module, code)
17931764
new_mod_op = MLIR.IR.Operation(new_mod)
17941765
body = MLIR.IR.body(new_mod)
17951766

17961767
operations = collect(MLIR.IR.OperationIterator(body))
17971768
for op in operations
1798-
if MLIR.IR.name(op) == "func.func"
1769+
if MLIR.IR.name(op) == func_op_kind
17991770
fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
18001771
if fn_name == func_name
18011772
fn = op
18021773
end
18031774

1804-
new_name = _hlo_call_name(fn_name, module_suffix)
18051775
res = MLIR.IR.LogicalResult(
18061776
MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
1807-
fn_name, new_name, new_mod_op
1777+
fn_name, name_to_call, new_mod_op
18081778
),
18091779
)
18101780
@assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
@@ -1817,7 +1787,7 @@ julia> Reactant.@jit(
18171787
)
18181788

18191789
# Change function name
1820-
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
1790+
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
18211791
end
18221792
end
18231793

@@ -1831,11 +1801,59 @@ julia> Reactant.@jit(
18311801
error("hlo_call: could not find function $func_name in the provided module")
18321802
end
18331803

1804+
return name_to_call
1805+
end
1806+
1807+
function triton_call(
1808+
mlir_code::String,
1809+
args::Union{TracedRArray,TracedRNumber,Number}...;
1810+
func_name::String="main",
1811+
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
1812+
)
1813+
name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
1814+
1815+
@show name_to_call
1816+
display(MLIR.IR.mmodule())
1817+
1818+
error("TODO: implement triton_call")
1819+
end
1820+
1821+
"""
1822+
hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1823+
1824+
Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1825+
with the provided arguments and return a tuple for each result of the call.
1826+
1827+
```julia-repl
1828+
julia> Reactant.@jit(
1829+
hlo_call(
1830+
\"\"\"
1831+
module {
1832+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1833+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1834+
return %0 : tensor<3xf32>
1835+
}
1836+
}
1837+
\"\"\",
1838+
Reactant.to_rarray(Float32[1, 2, 3]),
1839+
Reactant.to_rarray(Float32[1, 2, 3]),
1840+
)
1841+
)
1842+
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1843+
```
1844+
"""
1845+
@noinline function hlo_call(
1846+
code,
1847+
args::Union{TracedRArray,TracedRNumber}...;
1848+
func_name="main",
1849+
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
1850+
)
1851+
name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
1852+
18341853
ftype_attr = MLIR.IR.attr(fn, "function_type")
18351854
ftype = MLIR.IR.Type(ftype_attr)
18361855

1837-
@assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
1838-
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"
1856+
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name. Expected $(MLIR.IR.ninputs(ftype)), got $(length(args))"
18391857

18401858
for (i, arg) in enumerate(args)
18411859
expected_type = MLIR.IR.input(ftype, i)

0 commit comments

Comments
 (0)