Skip to content

Commit daf2ea0

Browse files
authored
feat: support closures in batching (#1846)
1 parent f5e61ff commit daf2ea0

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/Ops.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3130,6 +3130,7 @@ end
31303130
[size(input, i) for i in (length(batch_shape) + 1):ndims(input)]...,
31313131
) for input in inputs
31323132
]
3133+
argprefix = gensym("batcharg")
31333134
mlir_fn_res = Reactant.TracedUtils.make_mlir_fn(
31343135
f,
31353136
(sample_inputs...,),
@@ -3138,11 +3139,35 @@ end
31383139
false;
31393140
args_in_result=:none,
31403141
do_transpose=false,
3142+
argprefix,
31413143
)
31423144

31433145
func = mlir_fn_res.f
31443146
@assert MLIR.IR.nregions(func) == 1
31453147

3148+
if mlir_fn_res.fnwrapped
3149+
# In the long-term we should be able to do per-argument batching.
3150+
# Rn we simply broadcast_in_dim the arguments to the correct shape.
3151+
final_inputs = TracedRArray[]
3152+
seenargs = Reactant.OrderedIdDict()
3153+
Reactant.make_tracer(
3154+
seenargs, f, (argprefix, 1), Reactant.TracedSetPath; toscalar=false
3155+
)
3156+
for (k, v) in seenargs
3157+
v isa Reactant.TracedType || continue
3158+
bcasted_arg = broadcast_in_dim(
3159+
v,
3160+
collect(Int64, (length(batch_shape) + 1):(ndims(v) + length(batch_shape))),
3161+
vcat(batch_shape, collect(Int64, size(v)));
3162+
location,
3163+
)
3164+
push!(final_inputs, bcasted_arg)
3165+
end
3166+
append!(final_inputs, inputs)
3167+
else
3168+
final_inputs = inputs
3169+
end
3170+
31463171
output_types = MLIR.IR.Type[]
31473172
for result in mlir_fn_res.linear_results
31483173
push!(
@@ -3154,7 +3179,7 @@ end
31543179
)
31553180
end
31563181

3157-
return batch(inputs, output_types, batch_shape; fn=func, location)
3182+
return batch(final_inputs, output_types, batch_shape; fn=func, location)
31583183
end
31593184

31603185
@noinline function batch(

test/batching.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,18 @@ end
8585

8686
run_auto_batching_tests(naive_batched_matmul, x, y)
8787
end
88+
89+
function batch_with_closure(x, y)
90+
_fn(x) = x .+ y
91+
return mapslices(_fn, x; dims=2)
92+
end
93+
94+
@testset "Batching with closure" begin
95+
x = Reactant.TestUtils.construct_test_array(Float32, 3, 256, 8)
96+
y = Reactant.TestUtils.construct_test_array(Float32, 256)
97+
98+
x_ra = Reactant.to_rarray(x)
99+
y_ra = Reactant.to_rarray(y)
100+
101+
@test @jit(batch_with_closure(x_ra, y_ra)) batch_with_closure(x, y)
102+
end

0 commit comments

Comments
 (0)