@@ -15,6 +15,7 @@ using ..Reactant:
1515 promote_to, # keep this to avoid breaking external code
1616 broadcast_to_size # keep this to avoid breaking external code
1717using .. Ops: @opcall
18+ using GPUArraysCore: @allowscalar
1819using ReactantCore: ReactantCore
1920using ReactantCore: MissingTracedValue, is_traced, materialize_traced_array
2021
@@ -1086,6 +1087,49 @@ function set!(x, path, tostore; emptypath=false)
10861087 return emptypath && set_paths! (x, ())
10871088end
10881089
"""
    __elem_apply_loop_condition(idx_ref, fn_ref, res_ref, args_ref, L_ref)

Continuation predicate for the element-wise while loop.

Return whether the counter stored in `idx_ref` is still strictly below the limit stored
in `L_ref`. `fn_ref`, `res_ref` and `args_ref` are not read here; they are accepted so
that this function takes the same argument list as `__elem_apply_loop_body`.
"""
function __elem_apply_loop_condition(
    idx_ref, fn_ref::F, res_ref, args_ref, L_ref
) where {F}
    processed = idx_ref[]
    limit = L_ref[]
    return processed < limit
end
1093+
"""
    __elem_apply_loop_body(idx_ref, fn_ref, res_ref, args_ref, L_ref)

Perform one iteration of the element-wise while loop.

Advances the counter held in `idx_ref` by one, reads the element at that position from
every array in `args_ref[]`, applies `fn_ref[]` to those elements, and writes the outcome
to the same position of `res_ref[]`. The updated counter and result array are stored back
into their refs. `L_ref` is not read here. Always returns `nothing`.
"""
function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
    # The stored counter is one behind the position handled in this iteration.
    position = idx_ref[] + 1
    output = res_ref[]
    inputs = args_ref[]
    elementwise = fn_ref[]

    # Scalar indexing is explicitly permitted for each element read and for the write.
    picked = [@allowscalar(input[position]) for input in inputs]
    @allowscalar output[position] = elementwise(picked...)

    # Publish the new state through the refs for the next iteration.
    res_ref[] = output
    idx_ref[] = position
    return nothing
end
1107+
"""
    elem_apply_via_while_loop(f, args...)

Apply `f` element-wise over `args` using a loop driven by `ReactantCore.traced_while`.

Every argument is flattened with `vec` and materialized. `f` is called once on the first
element of each argument to determine the result element type (via
`Reactant.unwrapped_eltype`), a result vector of matching length is allocated with
`similar`, and the loop fills it one position at a time using
`__elem_apply_loop_condition` / `__elem_apply_loop_body`.

# Returns
The materialized result, reshaped to `size(first(args))`.

# Throws
- `DimensionMismatch`: if the arguments do not all have the same `size`.
"""
function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
    # Explicit check instead of `@assert`: this validates caller input, and assertions
    # are meant for internal invariants and may be disabled.
    allequal(size.(args)) ||
        throw(DimensionMismatch("All args must have the same size"))
    L = length(first(args))
    # flattening the tensors makes the auto-batching pass work nicer
    flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args]

    # This won't be a mutating function so we can safely execute it once
    # NOTE(review): the probe indexes element 1, so inputs are presumably expected to be
    # non-empty here — confirm against callers.
    res_tmp = @allowscalar f([arg[1] for arg in flat_args]...)
    result = similar(first(flat_args), Reactant.unwrapped_eltype(res_tmp), L)

    # Loop state is passed through `Ref`s; the counter starts at 0 and the body
    # increments it before indexing.
    ind_var = Ref(0)
    f_ref = Ref(f)
    result_ref = Ref(result)
    args_ref = Ref(flat_args)
    limit_ref = Ref(L)

    ReactantCore.traced_while(
        __elem_apply_loop_condition,
        __elem_apply_loop_body,
        (ind_var, f_ref, result_ref, args_ref, limit_ref),
    )

    return ReactantCore.materialize_traced_array(reshape(result, size(first(args))))
end
1132+
10891133function elem_apply (f, args:: Vararg{Any,Nargs} ) where {Nargs}
10901134 if all (iszero ∘ ndims, args)
10911135 scalar_args = map (args) do arg
@@ -1094,6 +1138,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
10941138 return Reactant. call_with_reactant (f, scalar_args... )
10951139 end
10961140
1141+ # we can expand the scope of this later to support cases where the output
1142+ # doesn't align with `Ops.batch`. For now we just handle cases that would
1143+ # obviously fail with scalarizing the inputs.
1144+ if Reactant. use_overlayed_version (f)
1145+ return elem_apply_via_while_loop (f, args... )
1146+ end
1147+
10971148 argprefix:: Symbol = gensym (" broadcastarg" )
10981149 resprefix:: Symbol = gensym (" broadcastresult" )
10991150 resargprefix:: Symbol = gensym (" broadcastresarg" )
0 commit comments