@@ -12,9 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import sys
+import warnings
+
+import aesara.tensor as at
+import numpy as np
 import pytest

+import pymc as pm
+
+from pymc.aesaraf import floatX
+from pymc.exceptions import SamplingError
+from pymc.step_methods.hmc import NUTS
 from pymc.tests import sampler_fixtures as sf
+from pymc.tests.helpers import RVsAssignmentStepsTester, StepMethodTester


 class TestNUTSUniform(sf.NutsFixture, sf.UniformFixture):
@@ -81,3 +92,116 @@ class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture): |
     burn = 0
     chains = 2
     min_n_eff = 200
+
+
+class TestNutsCheckTrace:
+    def test_multiple_samplers(self, caplog):
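+        # `outcome` is an unobserved discrete variable, so pm.sample has to
+        # assign more than one step method (NUTS for `prob`, a discrete
+        # sampler for `outcome`); no "boolean index" errors should be logged.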
+        with pm.Model():
+            prob = pm.Beta("prob", alpha=5.0, beta=3.0)
+            pm.Binomial("outcome", n=1, p=prob)
+            caplog.clear()
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                pm.sample(3, tune=2, discard_tuned_samples=False, n_init=None, chains=1)
+            messages = [msg.msg for msg in caplog.records]
+            assert all("boolean index did not" not in msg for msg in messages)
+
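+    # A negative initval on an untransformed HalfNormal makes the starting
+    # point have -inf logp, so sampling must abort with a SamplingError.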
+    def test_bad_init_nonparallel(self):
+        with pm.Model():
+            pm.HalfNormal("a", sigma=1, initval=-1, transform=None)
+            with pytest.raises(SamplingError) as error:
+                pm.sample(chains=1, random_seed=1)
+            error.match("Initial evaluation")
+
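+    # Same failure mode, exercised through the multi-core sampling path.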
+    @pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")
+    def test_bad_init_parallel(self):
+        with pm.Model():
+            pm.HalfNormal("a", sigma=1, initval=-1, transform=None)
+            with pytest.raises(SamplingError) as error:
+                pm.sample(cores=2, random_seed=1)
+            error.match("Initial evaluation")
+
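+    # The switch injects infinities into the input of `solve`, feeding NUTS
+    # non-finite values; the test checks that the resulting divergences are
+    # recorded in the trace and surfaced via log warnings and the report.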
+    def test_linalg(self, caplog):
+        with pm.Model():
+            a = pm.Normal("a", size=2, initval=floatX(np.zeros(2)))
+            a = at.switch(a > 0, np.inf, a)
+            b = at.slinalg.solve(floatX(np.eye(2)), a, check_finite=False)
+            pm.Normal("c", mu=b, size=2, initval=floatX(np.r_[0.0, 0.0]))
+            caplog.clear()
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                trace = pm.sample(20, tune=5, chains=2, return_inferencedata=False, random_seed=526)
+            warns = [msg.msg for msg in caplog.records]
+            assert np.any(trace["diverging"])
+            assert (
+                any("divergence after tuning" in warn for warn in warns)
+                or any("divergences after tuning" in warn for warn in warns)
+                or any("only diverging samples" in warn for warn in warns)
+            )
+
+            with pytest.raises(ValueError) as error:
+                trace.report.raise_ok()
+            error.match("issues during sampling")
+
+            assert not trace.report.ok
+
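+    # NUTS should expose its full set of sampler statistics, and the model
+    # logp tracked during sampling should match a post-hoc recomputation.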
+    def test_sampler_stats(self):
+        with pm.Model() as model:
+            pm.Normal("x", mu=0, sigma=1)
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                trace = pm.sample(draws=10, tune=1, chains=1, return_inferencedata=False)
+
+        # Assert stats exist and have the correct shape.
+        expected_stat_names = {
+            "depth",
+            "diverging",
+            "energy",
+            "energy_error",
+            "model_logp",
+            "max_energy_error",
+            "mean_tree_accept",
+            "step_size",
+            "step_size_bar",
+            "tree_size",
+            "tune",
+            "perf_counter_diff",
+            "perf_counter_start",
+            "process_time_diff",
+            "index_in_trajectory",
+            "largest_eigval",
+            "smallest_eigval",
+        }
+        assert trace.stat_names == expected_stat_names
+        for varname in trace.stat_names:
+            assert trace.get_sampler_stats(varname).shape == (10,)
+
+        # Assert model logp is computed correctly: computing post-sampling
+        # and tracking while sampling should give same results.
+        model_logp_fn = model.compile_logp()
+        model_logp_ = np.array(
+            [
+                model_logp_fn(trace.point(i, chain=c))
+                for c in trace.chains
+                for i in range(len(trace))
+            ]
+        )
+        assert (trace.model_logp == model_logp_).all()
+
+
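+# Run NUTS through the generic StepMethodTester harness, with blocked=False
+# and with the default blocking.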
+class TestStepNUTS(StepMethodTester):
+    @pytest.mark.parametrize(
+        "step_fn, draws",
+        [
+            (lambda C, _: NUTS(scaling=C, is_cov=True, blocked=False), 1000),
+            (lambda C, _: NUTS(scaling=C, is_cov=True), 1000),
+        ],
+    )
+    def test_step_continuous(self, step_fn, draws):
+        self.step_continuous(step_fn, draws)
+
+
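+# Sanity check that NUTS is correctly assigned to continuous model variables.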
+class TestRVsAssignmentNUTS(RVsAssignmentStepsTester):
+    @pytest.mark.parametrize("step, step_kwargs", [(NUTS, {})])
+    def test_continuous_steps(self, step, step_kwargs):
+        self.continuous_steps(step, step_kwargs)