# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import numpy as np
import numpy.testing as npt
import pytest

import pymc as pm

from pymc.aesaraf import floatX
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.step_methods.hmc import HamiltonianMC
from pymc.step_methods.hmc.base_hmc import BaseHMC
from pymc.tests import models
from pymc.tests.helpers import RVsAssignmentStepsTester, StepMethodTester

2030
class TestRVsAssignmentHamiltonianMC(RVsAssignmentStepsTester):
    @pytest.mark.parametrize("step, step_kwargs", [(HamiltonianMC, {})])
    def test_continuous_steps(self, step, step_kwargs):
        """Check that HMC gets assigned only to continuous random variables.

        Delegates to the shared ``continuous_steps`` helper on the
        ``RVsAssignmentStepsTester`` base class.
        """
        self.continuous_steps(step, step_kwargs)
def test_leapfrog_reversible():
    """Leapfrog integration must be (numerically) reversible.

    Integrating a phase-space point forward ``n_steps`` times with step size
    ``epsilon`` and then backward the same number of times with ``-epsilon``
    has to land back on the starting state up to floating-point tolerance.
    """
    dim = 3
    np.random.seed(42)
    init_point, model, _ = models.non_normal(dim)
    n_params = sum(init_point[v.name].size for v in model.value_vars)
    scaling = floatX(np.random.rand(n_params))

    class HMC(BaseHMC):
        # Minimal concrete subclass: the abstract sampling step is never
        # exercised here because we drive the integrator directly.
        def _hamiltonian_step(self, *args, **kwargs):
            pass

    step = HMC(vars=model.value_vars, model=model, scaling=scaling)

    step.integrator._logp_dlogp_func.set_extra_values({})
    raveled_start = DictToArrayBijection.map(init_point)
    # Draw order matters for reproducibility: momentum first, then position.
    momentum = RaveledVars(floatX(step.potential.random()), raveled_start.point_map_info)
    position = RaveledVars(floatX(np.random.randn(n_params)), raveled_start.point_map_info)
    initial_state = step.integrator.compute_state(momentum, position)
    for epsilon in (0.01, 0.1):
        for n_steps in (1, 2, 3, 4, 20):
            state = initial_state
            # Forward trajectory ...
            for _ in range(n_steps):
                state = step.integrator.step(epsilon, state)
            # ... then the same number of backward steps should undo it.
            for _ in range(n_steps):
                state = step.integrator.step(-epsilon, state)
            npt.assert_allclose(state.q.data, initial_state.q.data, rtol=1e-5)
            npt.assert_allclose(state.p.data, initial_state.p.data, rtol=1e-5)
def test_nuts_tuning():
    """NUTS must stop tuning after the tuning phase.

    After sampling finishes, ``step.tune`` is False and every post-tuning
    step size equals the last step size reached during warmup.
    """
    with pm.Model():
        pm.Normal("mu", mu=0, sigma=1)
        step = pm.NUTS()
        with warnings.catch_warnings():
            # The tiny draw count is deliberate; silence the chain-length warning.
            warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
            idata = pm.sample(
                10, step=step, tune=5, discard_tuned_samples=False, progressbar=False, chains=1
            )

        # Tuning is switched off once the warmup phase is over.
        assert not step.tune
        ss_tuned = idata.warmup_sample_stats["step_size"][0, -1]
        ss_posterior = idata.sample_stats["step_size"][0, :]
        # Use npt for consistency with the other tests in this file.
        npt.assert_array_equal(ss_posterior, ss_tuned)