Skip to content

Commit 0dac900

Browse files
authored
feat: conditionally lower elem_apply to a for loop (#1816)
1 parent 953a1ff commit 0dac900

File tree

3 files changed

+83
-3
lines changed

3 files changed

+83
-3
lines changed

src/Overlay.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ end
187187
end
188188

189189
@reactant_overlay @noinline function Base.map(f, x::AbstractArray, ys::AbstractArray...)
190-
if use_overlayed_version(x) || looped_any(use_overlayed_version, ys)
190+
if (
191+
use_overlayed_version(x) ||
192+
use_overlayed_version(f) ||
193+
looped_any(use_overlayed_version, ys)
194+
)
191195
return TracedRArrayOverrides.overloaded_map(f, x, ys...)
192196
else
193197
return Base.inferencebarrier(Base.map)(CallWithReactant(f), x, ys...)
@@ -200,6 +204,7 @@ end
200204
if (
201205
use_overlayed_version(y) ||
202206
use_overlayed_version(x) ||
207+
use_overlayed_version(f) ||
203208
looped_any(use_overlayed_version, xs)
204209
)
205210
return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...)
@@ -209,15 +214,15 @@ end
209214
end
210215

211216
@reactant_overlay @noinline function Base._all(f, x::AbstractArray, dims)
212-
if use_overlayed_version(x)
217+
if use_overlayed_version(x) || use_overlayed_version(f)
213218
return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims)
214219
else
215220
return Base.inferencebarrier(Base._all)(CallWithReactant(f), x, dims)
216221
end
217222
end
218223

219224
@reactant_overlay @noinline function Base._any(f, x::AbstractArray, dims)
220-
if use_overlayed_version(x)
225+
if use_overlayed_version(x) || use_overlayed_version(f)
221226
return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims)
222227
else
223228
return Base.inferencebarrier(Base._any)(CallWithReactant(f), x, dims)

src/TracedUtils.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using ..Reactant:
1515
promote_to, # keep this to avoid breaking external code
1616
broadcast_to_size # keep this to avoid breaking external code
1717
using ..Ops: @opcall
18+
using GPUArraysCore: @allowscalar
1819
using ReactantCore: ReactantCore
1920
using ReactantCore: MissingTracedValue, is_traced, materialize_traced_array
2021

@@ -1086,6 +1087,49 @@ function set!(x, path, tostore; emptypath=false)
10861087
return emptypath && set_paths!(x, ())
10871088
end
10881089

1090+
function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
1091+
return idx_ref[] < L_ref[]
1092+
end
1093+
1094+
function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
1095+
args = args_ref[]
1096+
fn = fn_ref[]
1097+
res = res_ref[]
1098+
idx = idx_ref[] + 1
1099+
1100+
scalar_args = [@allowscalar(arg[idx]) for arg in args]
1101+
@allowscalar res[idx] = fn(scalar_args...)
1102+
1103+
idx_ref[] = idx
1104+
res_ref[] = res
1105+
return nothing
1106+
end
1107+
1108+
function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
1109+
@assert allequal(size.(args)) "All args must have the same size"
1110+
L = length(first(args))
1111+
# flattening the tensors makes the auto-batching pass work nicer
1112+
flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args]
1113+
1114+
# This wont be a mutating function so we can safely execute it once
1115+
res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...))
1116+
result = similar(first(flat_args), Reactant.unwrapped_eltype(res_tmp), L)
1117+
1118+
ind_var = Ref(0)
1119+
f_ref = Ref(f)
1120+
result_ref = Ref(result)
1121+
args_ref = Ref(flat_args)
1122+
limit_ref = Ref(L)
1123+
1124+
ReactantCore.traced_while(
1125+
__elem_apply_loop_condition,
1126+
__elem_apply_loop_body,
1127+
(ind_var, f_ref, result_ref, args_ref, limit_ref),
1128+
)
1129+
1130+
return ReactantCore.materialize_traced_array(reshape(result, size(first(args))))
1131+
end
1132+
10891133
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
10901134
if all(iszero ndims, args)
10911135
scalar_args = map(args) do arg
@@ -1094,6 +1138,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
10941138
return Reactant.call_with_reactant(f, scalar_args...)
10951139
end
10961140

1141+
# we can expand the scope of this later to support cases where the output
1142+
# doesn't align with `Ops.batch`. For now we just handle cases that would
1143+
# obviously fail with scalarizing the inputs.
1144+
if Reactant.use_overlayed_version(f)
1145+
return elem_apply_via_while_loop(f, args...)
1146+
end
1147+
10971148
argprefix::Symbol = gensym("broadcastarg")
10981149
resprefix::Symbol = gensym("broadcastresult")
10991150
resargprefix::Symbol = gensym("broadcastresarg")

test/batching.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,27 @@ end
100100

101101
@test @jit(batch_with_closure(x_ra, y_ra)) batch_with_closure(x, y)
102102
end
103+
104+
function map_with_scalar_indexing(i, x, y)
105+
c = max(x[i], y[i])
106+
return x[i] + y[i] + c
107+
end
108+
109+
function mctr(f, range, x, y)
110+
f2(i) = f(i, x, y)
111+
return map(f2, range)
112+
end
113+
114+
@testset "map with scalar indexing" begin
115+
input1 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
116+
input2 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
117+
118+
hlo = @code_hlo optimize = false mctr(map_with_scalar_indexing, 1:8, input1, input2)
119+
@test contains(repr(hlo), "stablehlo.while")
120+
hlo = @code_hlo optimize = true mctr(map_with_scalar_indexing, 1:8, input1, input2)
121+
@test !contains(repr(hlo), "stablehlo.while")
122+
123+
res_ra = @jit mctr(map_with_scalar_indexing, 1:8, input1, input2)
124+
res = mctr(map_with_scalar_indexing, 1:8, Array(input1), Array(input2))
125+
@test res_ra res
126+
end

0 commit comments

Comments
 (0)