Skip to content

Commit e92543f

Browse files
committed
fix: hlo_call
1 parent bf973f6 commit e92543f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/Ops.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,7 +1756,7 @@ function _extract_function(
17561756
error("hlo_call: could not find function $func_name in the provided module")
17571757
end
17581758

1759-
return name_to_call
1759+
return fn, name_to_call
17601760
end
17611761

17621762
function triton_call(
@@ -1770,7 +1770,7 @@ function triton_call(
17701770
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
17711771
# TODO: other kwargs
17721772
)
1773-
name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
1773+
_, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
17741774

17751775
enzymexla.triton_call(
17761776
grid_x.mlir_data,
@@ -1816,7 +1816,7 @@ julia> Reactant.@jit(
18161816
func_name="main",
18171817
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
18181818
)
1819-
name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
1819+
fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
18201820

18211821
ftype_attr = MLIR.IR.attr(fn, "function_type")
18221822
ftype = MLIR.IR.Type(ftype_attr)

0 commit comments

Comments
 (0)