Skip to content

Commit 4c87ed7

Browse files
authored
Merge branch 'main' into probprog-hmc
2 parents 8560867 + 78cf63e commit 4c87ed7

File tree

5 files changed

+86
-6
lines changed

5 files changed

+86
-6
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.176"
4+
version = "0.2.177"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.258"
108+
Reactant_jll = "0.0.259"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "c9ad2b9fadb847b303202e2fc5239516bb7f7a92"
7+
ENZYMEXLA_COMMIT = "47d57c1cea7b24e210ad75aee6e7c3f93d89ff78"
88

99
ENZYMEXLA_SHA256 = ""
1010

src/Overlay.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ end
187187
end
188188

189189
@reactant_overlay @noinline function Base.map(f, x::AbstractArray, ys::AbstractArray...)
190-
if use_overlayed_version(x) || looped_any(use_overlayed_version, ys)
190+
if (
191+
use_overlayed_version(x) ||
192+
use_overlayed_version(f) ||
193+
looped_any(use_overlayed_version, ys)
194+
)
191195
return TracedRArrayOverrides.overloaded_map(f, x, ys...)
192196
else
193197
return Base.inferencebarrier(Base.map)(CallWithReactant(f), x, ys...)
@@ -200,6 +204,7 @@ end
200204
if (
201205
use_overlayed_version(y) ||
202206
use_overlayed_version(x) ||
207+
use_overlayed_version(f) ||
203208
looped_any(use_overlayed_version, xs)
204209
)
205210
return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...)
@@ -209,15 +214,15 @@ end
209214
end
210215

211216
@reactant_overlay @noinline function Base._all(f, x::AbstractArray, dims)
212-
if use_overlayed_version(x)
217+
if use_overlayed_version(x) || use_overlayed_version(f)
213218
return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims)
214219
else
215220
return Base.inferencebarrier(Base._all)(CallWithReactant(f), x, dims)
216221
end
217222
end
218223

219224
@reactant_overlay @noinline function Base._any(f, x::AbstractArray, dims)
220-
if use_overlayed_version(x)
225+
if use_overlayed_version(x) || use_overlayed_version(f)
221226
return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims)
222227
else
223228
return Base.inferencebarrier(Base._any)(CallWithReactant(f), x, dims)

src/TracedUtils.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using ..Reactant:
1515
promote_to, # keep this to avoid breaking external code
1616
broadcast_to_size # keep this to avoid breaking external code
1717
using ..Ops: @opcall
18+
using GPUArraysCore: @allowscalar
1819
using ReactantCore: ReactantCore
1920
using ReactantCore: MissingTracedValue, is_traced, materialize_traced_array
2021

@@ -1086,6 +1087,49 @@ function set!(x, path, tostore; emptypath=false)
10861087
return emptypath && set_paths!(x, ())
10871088
end
10881089

1090+
function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
1091+
return idx_ref[] < L_ref[]
1092+
end
1093+
1094+
function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
1095+
args = args_ref[]
1096+
fn = fn_ref[]
1097+
res = res_ref[]
1098+
idx = idx_ref[] + 1
1099+
1100+
scalar_args = [@allowscalar(arg[idx]) for arg in args]
1101+
@allowscalar res[idx] = fn(scalar_args...)
1102+
1103+
idx_ref[] = idx
1104+
res_ref[] = res
1105+
return nothing
1106+
end
1107+
1108+
function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
1109+
@assert allequal(size.(args)) "All args must have the same size"
1110+
L = length(first(args))
1111+
# flattening the tensors makes the auto-batching pass work nicer
1112+
flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args]
1113+
1114+
# This wont be a mutating function so we can safely execute it once
1115+
res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...))
1116+
result = similar(first(flat_args), Reactant.unwrapped_eltype(res_tmp), L)
1117+
1118+
ind_var = Ref(0)
1119+
f_ref = Ref(f)
1120+
result_ref = Ref(result)
1121+
args_ref = Ref(flat_args)
1122+
limit_ref = Ref(L)
1123+
1124+
ReactantCore.traced_while(
1125+
__elem_apply_loop_condition,
1126+
__elem_apply_loop_body,
1127+
(ind_var, f_ref, result_ref, args_ref, limit_ref),
1128+
)
1129+
1130+
return ReactantCore.materialize_traced_array(reshape(result, size(first(args))))
1131+
end
1132+
10891133
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
10901134
if all(iszero ndims, args)
10911135
scalar_args = map(args) do arg
@@ -1094,6 +1138,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
10941138
return Reactant.call_with_reactant(f, scalar_args...)
10951139
end
10961140

1141+
# we can expand the scope of this later to support cases where the output
1142+
# doesn't align with `Ops.batch`. For now we just handle cases that would
1143+
# obviously fail with scalarizing the inputs.
1144+
if Reactant.use_overlayed_version(f)
1145+
return elem_apply_via_while_loop(f, args...)
1146+
end
1147+
10971148
argprefix::Symbol = gensym("broadcastarg")
10981149
resprefix::Symbol = gensym("broadcastresult")
10991150
resargprefix::Symbol = gensym("broadcastresarg")

test/batching.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,27 @@ end
100100

101101
@test @jit(batch_with_closure(x_ra, y_ra)) batch_with_closure(x, y)
102102
end
103+
104+
function map_with_scalar_indexing(i, x, y)
105+
c = max(x[i], y[i])
106+
return x[i] + y[i] + c
107+
end
108+
109+
function mctr(f, range, x, y)
110+
f2(i) = f(i, x, y)
111+
return map(f2, range)
112+
end
113+
114+
@testset "map with scalar indexing" begin
115+
input1 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
116+
input2 = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 10))
117+
118+
hlo = @code_hlo optimize = false mctr(map_with_scalar_indexing, 1:8, input1, input2)
119+
@test contains(repr(hlo), "stablehlo.while")
120+
hlo = @code_hlo optimize = true mctr(map_with_scalar_indexing, 1:8, input1, input2)
121+
@test !contains(repr(hlo), "stablehlo.while")
122+
123+
res_ra = @jit mctr(map_with_scalar_indexing, 1:8, input1, input2)
124+
res = mctr(map_with_scalar_indexing, 1:8, Array(input1), Array(input2))
125+
@test res_ra res
126+
end

0 commit comments

Comments
 (0)