|
25 | 25 | import jax.numpy as jnp |
26 | 26 | import numpy as np |
27 | 27 |
|
| 28 | +from blackjax.smc import extend_params |
28 | 29 | from blackjax.smc.resampling import systematic |
29 | 30 | from pymc import draw, modelcontext, to_inference_data |
30 | 31 | from pymc.backends import NDArray |
@@ -126,16 +127,20 @@ def sample_smc_blackjax( |
126 | 127 |
|
127 | 128 | if kernel == "HMC": |
128 | 129 | mcmc_kernel = blackjax.mcmc.hmc |
129 | | - mcmc_parameters = dict( |
130 | | - step_size=inner_kernel_params["step_size"], |
131 | | - inverse_mass_matrix=jnp.eye(posterior_dimensions), |
132 | | - num_integration_steps=inner_kernel_params["integration_steps"], |
| 130 | + mcmc_parameters = extend_params( |
| 131 | + dict( |
| 132 | + step_size=inner_kernel_params["step_size"], |
| 133 | + inverse_mass_matrix=jnp.eye(posterior_dimensions), |
| 134 | + num_integration_steps=inner_kernel_params["integration_steps"], |
| 135 | + ) |
133 | 136 | ) |
134 | 137 | elif kernel == "NUTS": |
135 | 138 | mcmc_kernel = blackjax.mcmc.nuts |
136 | | - mcmc_parameters = dict( |
137 | | - step_size=inner_kernel_params["step_size"], |
138 | | - inverse_mass_matrix=jnp.eye(posterior_dimensions), |
| 139 | + mcmc_parameters = extend_params( |
| 140 | + dict( |
| 141 | + step_size=inner_kernel_params["step_size"], |
| 142 | + inverse_mass_matrix=jnp.eye(posterior_dimensions), |
| 143 | + ) |
139 | 144 | ) |
140 | 145 | else: |
141 | 146 | raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'") |
|
0 commit comments