Skip to content

Commit a4bda2d

Browse files
committed
inv mass matrices instead
1 parent 2d34ca5 commit a4bda2d

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

src/probprog/HMC.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function hmc(
66
f::Function,
77
args::Vararg{Any,Nargs};
88
selection::Selection,
9-
mass=nothing,
9+
inverse_mass_matrix=nothing,
1010
step_size=nothing,
1111
num_steps=nothing,
1212
initial_momentum=nothing,
@@ -53,9 +53,9 @@ function hmc(
5353
0::Int32, # 0 = HMC
5454
)::MLIR.IR.Attribute
5555

56-
mass_val = nothing
57-
if !isnothing(mass)
58-
mass_val = TracedUtils.get_mlir_data(mass)
56+
inverse_mass_matrix_val = nothing
57+
if !isnothing(inverse_mass_matrix)
58+
inverse_mass_matrix_val = TracedUtils.get_mlir_data(inverse_mass_matrix)
5959
end
6060

6161
step_size_val = nothing
@@ -76,7 +76,7 @@ function hmc(
7676
hmc_op = MLIR.Dialects.enzyme.mcmc(
7777
mlir_caller_args,
7878
trace_val,
79-
mass_val;
79+
inverse_mass_matrix_val;
8080
step_size=step_size_val,
8181
num_steps=num_steps_val,
8282
initial_momentum=initial_momentum_val,

test/probprog/hmc.jl

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Reactant, Test, Random
22
using Statistics
3-
using Reactant: ProbProg, ReactantRNG
3+
using Reactant: ProbProg, ReactantRNG, Profiler
44

55
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
66

@@ -34,7 +34,7 @@ function hmc_program(
3434
xs,
3535
step_size,
3636
num_steps,
37-
mass,
37+
inverse_mass_matrix,
3838
initial_momentum,
3939
constraint,
4040
constrained_addresses,
@@ -47,10 +47,10 @@ function hmc_program(
4747
model,
4848
xs;
4949
selection=ProbProg.select(ProbProg.Address(:param_a), ProbProg.Address(:param_b)),
50-
mass=mass,
51-
step_size=step_size,
52-
num_steps=num_steps,
53-
initial_momentum=initial_momentum,
50+
inverse_mass_matrix,
51+
step_size,
52+
num_steps,
53+
initial_momentum,
5454
)
5555

5656
return t, accepted
@@ -71,7 +71,7 @@ end
7171
step_size = ConcreteRNumber(0.001)
7272
num_steps_compile = ConcreteRNumber(1000)
7373
num_steps_run = ConcreteRNumber(40000000)
74-
mass = nothing
74+
inverse_mass_matrix = ConcreteRArray([1.0 0.0; 0.0 1.0])
7575
initial_momentum = ConcreteRArray([0.0, 0.0])
7676

7777
code = @code_hlo optimize = :probprog hmc_program(
@@ -80,7 +80,7 @@ end
8080
xs,
8181
step_size,
8282
num_steps_compile,
83-
mass,
83+
inverse_mass_matrix,
8484
initial_momentum,
8585
obs,
8686
constrained_addresses,
@@ -96,7 +96,7 @@ end
9696
xs,
9797
step_size,
9898
num_steps_compile,
99-
mass,
99+
inverse_mass_matrix,
100100
initial_momentum,
101101
obs,
102102
constrained_addresses,
@@ -106,19 +106,37 @@ end
106106

107107
seed_buffer = only(rng.seed.data).buffer
108108
trace = nothing
109+
enable_profiling = false
110+
109111
GC.@preserve seed_buffer obs begin
110112
run_time_s = @elapsed begin
111-
trace, _ = compiled_fn(
112-
rng,
113-
model,
114-
xs,
115-
step_size,
116-
num_steps_run,
117-
mass,
118-
initial_momentum,
119-
obs,
120-
constrained_addresses,
121-
)
113+
if enable_profiling
114+
Profiler.with_profiler("./traces"; create_perfetto_link=true) do
115+
trace, _ = compiled_fn(
116+
rng,
117+
model,
118+
xs,
119+
step_size,
120+
num_steps_run,
121+
inverse_mass_matrix,
122+
initial_momentum,
123+
obs,
124+
constrained_addresses,
125+
)
126+
end
127+
else
128+
trace, _ = compiled_fn(
129+
rng,
130+
model,
131+
xs,
132+
step_size,
133+
num_steps_run,
134+
inverse_mass_matrix,
135+
initial_momentum,
136+
obs,
137+
constrained_addresses,
138+
)
139+
end
122140
trace = ProbProg.ProbProgTrace(trace)
123141
end
124142
println("HMC run time: $(round(run_time_s * 1000, digits=2)) ms")

0 commit comments

Comments
 (0)