11using Reactant, Test, Random
22using Statistics
3- using Reactant: ProbProg, ReactantRNG
3+ using Reactant: ProbProg, ReactantRNG, Profiler
44
55normal (rng, μ, σ, shape) = μ .+ σ .* randn (rng, shape)
66
@@ -34,7 +34,7 @@ function hmc_program(
3434 xs,
3535 step_size,
3636 num_steps,
37- mass ,
37+ inverse_mass_matrix ,
3838 initial_momentum,
3939 constraint,
4040 constrained_addresses,
@@ -47,10 +47,10 @@ function hmc_program(
4747 model,
4848 xs;
4949 selection= ProbProg. select (ProbProg. Address (:param_a ), ProbProg. Address (:param_b )),
50- mass = mass ,
51- step_size= step_size ,
52- num_steps= num_steps ,
53- initial_momentum= initial_momentum ,
50+ inverse_mass_matrix ,
51+ step_size,
52+ num_steps,
53+ initial_momentum,
5454 )
5555
5656 return t, accepted
7171 step_size = ConcreteRNumber (0.001 )
7272 num_steps_compile = ConcreteRNumber (1000 )
7373 num_steps_run = ConcreteRNumber (40000000 )
74- mass = nothing
74+ inverse_mass_matrix = ConcreteRArray ([ 1.0 0.0 ; 0.0 1.0 ])
7575 initial_momentum = ConcreteRArray ([0.0 , 0.0 ])
7676
7777 code = @code_hlo optimize = :probprog hmc_program (
8080 xs,
8181 step_size,
8282 num_steps_compile,
83- mass ,
83+ inverse_mass_matrix ,
8484 initial_momentum,
8585 obs,
8686 constrained_addresses,
9696 xs,
9797 step_size,
9898 num_steps_compile,
99- mass ,
99+ inverse_mass_matrix ,
100100 initial_momentum,
101101 obs,
102102 constrained_addresses,
@@ -106,19 +106,37 @@ end
106106
107107 seed_buffer = only (rng. seed. data). buffer
108108 trace = nothing
109+ enable_profiling = false
110+
109111 GC. @preserve seed_buffer obs begin
110112 run_time_s = @elapsed begin
111- trace, _ = compiled_fn (
112- rng,
113- model,
114- xs,
115- step_size,
116- num_steps_run,
117- mass,
118- initial_momentum,
119- obs,
120- constrained_addresses,
121- )
113+ if enable_profiling
114+ Profiler. with_profiler (" ./traces" ; create_perfetto_link= true ) do
115+ trace, _ = compiled_fn (
116+ rng,
117+ model,
118+ xs,
119+ step_size,
120+ num_steps_run,
121+ inverse_mass_matrix,
122+ initial_momentum,
123+ obs,
124+ constrained_addresses,
125+ )
126+ end
127+ else
128+ trace, _ = compiled_fn (
129+ rng,
130+ model,
131+ xs,
132+ step_size,
133+ num_steps_run,
134+ inverse_mass_matrix,
135+ initial_momentum,
136+ obs,
137+ constrained_addresses,
138+ )
139+ end
122140 trace = ProbProg. ProbProgTrace (trace)
123141 end
124142 println (" HMC run time: $(round (run_time_s * 1000 , digits= 2 )) ms" )
0 commit comments