Skip to content

Commit 560ae49

Browse files
committed
automated tracing of ProbProgTrace and Constraint structs
1 parent b225dc9 commit 560ae49

File tree

7 files changed

+102
-92
lines changed

7 files changed

+102
-92
lines changed

src/Reactant.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,13 @@ include("Tracing.jl")
246246
include("Compiler.jl")
247247

248248
include("Overlay.jl")
249-
include("probprog/ProbProg.jl")
250249

251250
# Serialization
252251
include("serialization/Serialization.jl")
253252

253+
# ProbProg
254+
include("probprog/ProbProg.jl")
255+
254256
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
255257
export ConcreteRArray,
256258
ConcreteRNumber,

src/probprog/Modeling.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
9292
seed_buffer = only(rng.seed.data).buffer
9393
GC.@preserve seed_buffer begin
9494
t, _, _ = compiled_fn(rng, f, args...)
95-
trace = from_trace_tensor(t)
95+
trace = ProbProgTrace(t)
9696
end
9797

9898
return trace, trace.weight
@@ -146,37 +146,30 @@ end
146146

147147
# Gen-like helper function.
148148
function generate_(
149-
rng::AbstractRNG,
150-
f::Function,
151-
args::Vararg{Any,Nargs};
152-
constraint::Constraint=Constraint(),
149+
rng::AbstractRNG, constraint::Constraint, f::Function, args::Vararg{Any,Nargs}
153150
) where {Nargs}
154151
trace = nothing
155152

156-
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
157-
158153
constrained_addresses = extract_addresses(constraint)
159154

160-
function wrapper_fn(rng, constraint_ptr, args...)
161-
return generate(rng, f, args...; constraint_ptr, constrained_addresses)
162-
end
163-
164-
compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...)
155+
compiled_fn = @compile optimize = :probprog generate(
156+
rng, constraint, f, args...; constrained_addresses
157+
)
165158

166159
seed_buffer = only(rng.seed.data).buffer
167160
GC.@preserve seed_buffer constraint begin
168-
t, _, _ = compiled_fn(rng, constraint_ptr, args...)
169-
trace = from_trace_tensor(t)
161+
t, _, _ = compiled_fn(rng, constraint, f, args...)
162+
trace = ProbProgTrace(t)
170163
end
171164

172165
return trace, trace.weight
173166
end
174167

175168
function generate(
176169
rng::AbstractRNG,
170+
constraint,
177171
f::Function,
178172
args::Vararg{Any,Nargs};
179-
constraint_ptr::TracedRNumber,
180173
constrained_addresses::Set{Address},
181174
) where {Nargs}
182175
args = (rng, args...)
@@ -193,7 +186,7 @@ function generate(
193186

194187
constraint_val = MLIR.IR.result(
195188
MLIR.Dialects.builtin.unrealized_conversion_cast(
196-
[TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty]
189+
[TracedUtils.get_mlir_data(constraint)]; outputs=[constraint_ty]
197190
),
198191
1,
199192
)

src/probprog/Types.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Base: ReentrantLock
2+
using ..Reactant: AbstractConcreteNumber, AbstractConcreteArray
23

34
mutable struct ProbProgTrace
45
choices::Dict{Symbol,Any}
@@ -9,6 +10,9 @@ mutable struct ProbProgTrace
910
function ProbProgTrace()
1011
return new(Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}())
1112
end
13+
function ProbProgTrace(x::Union{AbstractConcreteNumber,AbstractConcreteArray})
14+
return convert(ProbProgTrace, x)
15+
end
1216
end
1317

1418
struct Address
@@ -42,6 +46,9 @@ mutable struct Constraint <: AbstractDict{Address,Any}
4246

4347
Constraint() = new(Dict{Address,Any}())
4448
Constraint(d::Dict{Address,Any}) = new(d)
49+
function Constraint(x::Union{AbstractConcreteNumber,AbstractConcreteArray})
50+
return convert(Constraint, x)
51+
end
4552
end
4653

4754
Base.getindex(c::Constraint, k::Address) = c.dict[k]

src/probprog/Utils.jl

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@ using ..Reactant:
33
TracedUtils,
44
Ops,
55
TracedRArray,
6+
TracedRNumber,
67
Compiler,
78
OrderedIdDict,
8-
make_tracer,
99
TracedToTypes,
10-
TracedTrack,
1110
TracedType,
12-
TracedSetPath
13-
import ..Reactant: promote_to
11+
TracedTrack,
12+
TracedSetPath,
13+
ConcreteToTraced,
14+
AbstractConcreteArray,
15+
XLA,
16+
Sharding,
17+
to_number
18+
import ..Reactant: promote_to, make_tracer
19+
import ..Compiler: donate_argument!
1420

1521
"""
1622
process_probprog_function(f, args, op_name)
@@ -171,30 +177,59 @@ function process_probprog_outputs(
171177
return traced_result
172178
end
173179

174-
to_trace_tensor(t::ProbProgTrace) = promote_to(TracedRArray{UInt64,0}, t)
180+
function promote_to(::Type{TracedRArray{UInt64,0}}, t::Union{ProbProgTrace,Constraint})
181+
return Ops.fill(reinterpret(UInt64, pointer_from_objref(t)), Int64[])
182+
end
175183

176-
function from_trace_tensor(trace_tensor)
177-
while !isready(trace_tensor)
184+
function Base.convert(
185+
::Type{T}, x::AbstractConcreteArray
186+
) where {T<:Union{ProbProgTrace,Constraint}}
187+
while !isready(x)
178188
yield()
179189
end
180-
return unsafe_pointer_to_objref(Ptr{Any}(Array(trace_tensor)[1]))::ProbProgTrace
190+
return unsafe_pointer_to_objref(Ptr{Any}(collect(x)[1]))::T
181191
end
182192

183-
function promote_to(::Type{TracedRArray{UInt64,0}}, t::ProbProgTrace)
184-
ptr = reinterpret(UInt64, pointer_from_objref(t))
185-
return Ops.fill(ptr, Int64[])
193+
function Base.convert(
194+
::Type{T}, x::AbstractConcreteNumber
195+
) where {T<:Union{ProbProgTrace,Constraint}}
196+
while !isready(x)
197+
yield()
198+
end
199+
return unsafe_pointer_to_objref(Ptr{Any}(to_number(x)))::T
186200
end
187201

188-
to_constraint_tensor(c::Constraint) = promote_to(TracedRArray{UInt64,0}, c)
189-
190-
function from_constraint_tensor(constraint_tensor)
191-
while !isready(constraint_tensor)
192-
yield()
202+
function Base.getproperty(t::Union{ProbProgTrace,Constraint}, s::Symbol)
203+
if s === :data
204+
return ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(t))).data
205+
else
206+
return getfield(t, s)
193207
end
194-
return unsafe_pointer_to_objref(Ptr{Any}(Array(constraint_tensor)[1]))::Constraint
195208
end
196209

197-
function promote_to(::Type{TracedRArray{UInt64,0}}, c::Constraint)
198-
ptr = reinterpret(UInt64, pointer_from_objref(c))
199-
return Ops.fill(ptr, Int64[])
210+
function donate_argument!(
211+
::Any, ::Union{ProbProgTrace,Constraint}, ::Int, ::Any, ::Any
212+
)
213+
return nothing
214+
end
215+
216+
Base.@nospecializeinfer function make_tracer(
217+
seen,
218+
@nospecialize(prev::Union{ProbProgTrace,Constraint}),
219+
@nospecialize(path),
220+
mode;
221+
@nospecialize(sharding = Sharding.NoSharding()),
222+
kwargs...,
223+
)
224+
if mode == ConcreteToTraced
225+
haskey(seen, prev) && return seen[prev]::TracedRNumber{UInt64}
226+
result = TracedRNumber{UInt64}((path,), nothing)
227+
seen[prev] = result
228+
return result
229+
elseif mode == TracedToTypes
230+
push!(path, typeof(prev))
231+
return nothing
232+
else
233+
error("Unsupported mode for $(typeof(prev)): $mode")
234+
end
200235
end

test/probprog/generate.jl

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
rng = ReactantRNG(seed)
3535
μ = Reactant.ConcreteRNumber(0.0)
3636
σ = Reactant.ConcreteRNumber(1.0)
37-
trace, weight = ProbProg.generate_(rng, model, μ, σ, shape)
37+
trace, weight = ProbProg.generate_(rng, ProbProg.Constraint(), model, μ, σ, shape)
3838
@test mean(trace.retval[1]) 0.0 atol = 0.05 rtol = 0.05
3939
end
4040

@@ -47,7 +47,7 @@ end
4747

4848
constraint = ProbProg.Constraint(:s => (fill(0.1, shape),))
4949

50-
trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint)
50+
trace, weight = ProbProg.generate_(rng, constraint, model, μ, σ, shape)
5151

5252
@test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1]
5353

@@ -72,7 +72,7 @@ end
7272
:u => :y => (fill(0.3, shape),),
7373
)
7474

75-
trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint)
75+
trace, weight = ProbProg.generate_(rng, constraint, nested_model, μ, σ, shape)
7676

7777
@test trace.choices[:s][1] == fill(0.1, shape)
7878
@test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape)
@@ -108,33 +108,24 @@ end
108108

109109
constrained_addresses = ProbProg.extract_addresses(constraint1)
110110

111-
constraint_ptr1 = Reactant.ConcreteRNumber(
112-
reinterpret(UInt64, pointer_from_objref(constraint1))
111+
compiled_fn = @compile optimize = :probprog ProbProg.generate(
112+
rng, constraint1, model, μ, σ, shape; constrained_addresses
113113
)
114114

115-
wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate(
116-
rng, model, μ, σ, shape; constraint_ptr, constrained_addresses
117-
)
118-
119-
compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ)
120-
121115
trace1 = nothing
122116
seed_buffer = only(rng.seed.data).buffer
123117
GC.@preserve seed_buffer constraint1 begin
124-
trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ)
125-
trace1 = ProbProg.from_trace_tensor(trace1)
118+
trace1, _ = compiled_fn(rng, constraint1, model, μ, σ, shape)
119+
trace1 = ProbProg.ProbProgTrace(trace1)
126120
end
127121

128122
constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),))
129-
constraint_ptr2 = Reactant.ConcreteRNumber(
130-
reinterpret(UInt64, pointer_from_objref(constraint2))
131-
)
132123

133124
trace2 = nothing
134125
seed_buffer = only(rng.seed.data).buffer
135126
GC.@preserve seed_buffer constraint2 begin
136-
trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ)
137-
trace2 = ProbProg.from_trace_tensor(trace2)
127+
trace2, _ = compiled_fn(rng, constraint2, model, μ, σ, shape)
128+
trace2 = ProbProg.ProbProgTrace(trace2)
138129
end
139130

140131
@test trace1.choices[:s][1] != trace2.choices[:s][1]

test/probprog/hmc.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,10 @@ function hmc_program(
3636
num_steps,
3737
mass,
3838
initial_momentum,
39-
constraint_ptr,
39+
constraint,
4040
constrained_addresses,
4141
)
42-
t, _, _ = ProbProg.generate(
43-
rng,
44-
model,
45-
xs;
46-
constraint_ptr=constraint_ptr,
47-
constrained_addresses=constrained_addresses,
48-
)
42+
t, _, _ = ProbProg.generate(rng, constraint, model, xs; constrained_addresses)
4943

5044
t, accepted, _ = ProbProg.hmc(
5145
rng,
@@ -73,7 +67,6 @@ end
7367
:param_a => ([0.0],), :param_b => ([0.0],), :ys_a => (ys_a,), :ys_b => (ys_b,)
7468
)
7569
constrained_addresses = ProbProg.extract_addresses(obs)
76-
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs)))
7770

7871
step_size = ConcreteRNumber(0.001)
7972
num_steps_compile = ConcreteRNumber(1000)
@@ -89,7 +82,7 @@ end
8982
num_steps_compile,
9083
mass,
9184
initial_momentum,
92-
constraint_ptr,
85+
obs,
9386
constrained_addresses,
9487
)
9588
@test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace")
@@ -106,7 +99,7 @@ end
10699
num_steps_compile,
107100
mass,
108101
initial_momentum,
109-
constraint_ptr,
102+
obs,
110103
constrained_addresses,
111104
)
112105
end
@@ -116,18 +109,18 @@ end
116109
trace = nothing
117110
GC.@preserve seed_buffer obs begin
118111
run_time_s = @elapsed begin
119-
trace_ptr, _ = compiled_fn(
112+
trace, _ = compiled_fn(
120113
rng,
121114
model,
122115
xs,
123116
step_size,
124117
num_steps_run,
125118
mass,
126119
initial_momentum,
127-
constraint_ptr,
120+
obs,
128121
constrained_addresses,
129122
)
130-
trace = ProbProg.from_trace_tensor(trace_ptr)
123+
trace = ProbProg.ProbProgTrace(trace)
131124
end
132125
println("HMC run time: $(round(run_time_s * 1000, digits=2)) ms")
133126
end

0 commit comments

Comments
 (0)