Skip to content

Commit 933f67a

Browse files
committed
fix: hlo_call
1 parent 26d217f commit 933f67a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/Ops.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,7 +1801,7 @@ function _extract_function(
18011801
error("hlo_call: could not find function $func_name in the provided module")
18021802
end
18031803

1804-
return name_to_call
1804+
return fn, name_to_call
18051805
end
18061806

18071807
function triton_call(
@@ -1815,7 +1815,7 @@ function triton_call(
18151815
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18161816
# TODO: other kwargs
18171817
)
1818-
name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
1818+
_, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
18191819

18201820
enzymexla.triton_call(
18211821
grid_x.mlir_data,
@@ -1861,7 +1861,7 @@ julia> Reactant.@jit(
18611861
func_name="main",
18621862
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
18631863
)
1864-
name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
1864+
fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
18651865

18661866
ftype_attr = MLIR.IR.attr(fn, "function_type")
18671867
ftype = MLIR.IR.Type(ftype_attr)

0 commit comments

Comments
 (0)