@@ -3130,6 +3130,7 @@ end
31303130 [size (input, i) for i in (length (batch_shape) + 1 ): ndims (input)]. .. ,
31313131 ) for input in inputs
31323132 ]
3133+ argprefix = gensym (" batcharg" )
31333134 mlir_fn_res = Reactant. TracedUtils. make_mlir_fn (
31343135 f,
31353136 (sample_inputs... ,),
@@ -3138,11 +3139,35 @@ end
31383139 false ;
31393140 args_in_result= :none ,
31403141 do_transpose= false ,
3142+ argprefix,
31413143 )
31423144
31433145 func = mlir_fn_res. f
31443146 @assert MLIR. IR. nregions (func) == 1
31453147
3148+ if mlir_fn_res. fnwrapped
3149+ # In the long-term we should be able to do per-argument batching.
3150+ # Rn we simply broadcast_in_dim the arguments to the correct shape.
3151+ final_inputs = TracedRArray[]
3152+ seenargs = Reactant. OrderedIdDict ()
3153+ Reactant. make_tracer (
3154+ seenargs, f, (argprefix, 1 ), Reactant. TracedSetPath; toscalar= false
3155+ )
3156+ for (k, v) in seenargs
3157+ v isa Reactant. TracedType || continue
3158+ bcasted_arg = broadcast_in_dim (
3159+ v,
3160+ collect (Int64, (length (batch_shape) + 1 ): (ndims (v) + length (batch_shape))),
3161+ vcat (batch_shape, collect (Int64, size (v)));
3162+ location,
3163+ )
3164+ push! (final_inputs, bcasted_arg)
3165+ end
3166+ append! (final_inputs, inputs)
3167+ else
3168+ final_inputs = inputs
3169+ end
3170+
31463171 output_types = MLIR. IR. Type[]
31473172 for result in mlir_fn_res. linear_results
31483173 push! (
@@ -3154,7 +3179,7 @@ end
31543179 )
31553180 end
31563181
3157- return batch (inputs , output_types, batch_shape; fn= func, location)
3182+ return batch (final_inputs , output_types, batch_shape; fn= func, location)
31583183end
31593184
31603185@noinline function batch (
0 commit comments