@@ -14,6 +14,8 @@
 from pytensor.graph.basic import graph_inputs
 
 from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
+from pymc_extras.statespace.filters.kalman_filter import StandardFilter
+from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
 from pymc_extras.statespace.models import structural as st
 from pymc_extras.statespace.models.utilities import make_default_coords
 from pymc_extras.statespace.utils.constants import (
@@ -1025,3 +1027,47 @@ def test_insert_batched_rvs(ss_mod, batch_size): |
         ss_mod._insert_random_variables()
         matrices = ss_mod.unpack_statespace()
         assert matrices[4].type.shape == (*batch_size, 2, 2)
+
+
+@pytest.mark.parametrize("batch_size", [(10,), (10, 3, 5)])
+def test_insert_batched_rvs_in_kf(ss_mod, batch_size):
+    data = pt.as_tensor(np.random.normal(size=(*batch_size, 7, 1)).astype(floatX))
+    data.name = "data"
+    kf = StandardFilter()
+
+    with pm.Model():
+        rho = pm.Normal("rho", shape=batch_size)
+        zeta = pm.Normal("zeta", shape=batch_size)
+        ss_mod._insert_random_variables()
+
+        matrices = x0, P0, c, d, T, Z, R, H, Q = ss_mod.unpack_statespace()
+        outputs = kf.build_graph(data, *matrices)
+
+    logp = outputs.pop(-1)
+    states, covs = outputs[:3], outputs[3:]
+    filtered_states, predicted_states, observed_states = states
+    filtered_covariances, predicted_covariances, observed_covariances = covs
+
+    assert logp.type.shape == (*batch_size, 7)
+    assert filtered_states.type.shape == (*batch_size, 7, 2)
+    assert predicted_states.type.shape == (*batch_size, 7, 2)
+    assert observed_states.type.shape == (*batch_size, 7, 1)
+    assert filtered_covariances.type.shape == (*batch_size, 7, 2, 2)
+    assert predicted_covariances.type.shape == (*batch_size, 7, 2, 2)
+    assert observed_covariances.type.shape == (*batch_size, 7, 1, 1)
+
+    ks = KalmanSmoother()
+    smoothed_states, smoothed_covariances = ks.build_graph(
+        T, R, Q, filtered_states, filtered_covariances
+    )
+    assert smoothed_states.type.shape == (
+        *batch_size,
+        None,
+        2,
+    )  # TODO: why do we lose the time dimension here?
+    assert smoothed_covariances.type.shape == (
+        *batch_size,
+        None,
+        2,
+        2,
+    )  # TODO: why do we lose the time dimension here?
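A side note on the two TODOs above (an editor's sketch, not part of the diff): the smoother's backward pass is presumably built with a PyTensor scan, and scan outputs typically report an unknown static length along the iteration axis, which would explain the `None` where the time dimension is expected. If that is the cause, a small hypothetical helper (`assert_shape_with_unknown_time` below is not an existing utility in pymc-extras) could keep the shape checks strict everywhere except that axis:

```python
def assert_shape_with_unknown_time(var, expected, time_axis):
    """Check a static shape, tolerating an unknown (None) time dimension.

    Illustration only: scan-built outputs often carry None for the
    iteration (time) axis in their static type even when every other
    dimension is known.
    """
    actual = var.type.shape
    assert len(actual) == len(expected)
    for axis, (got, want) in enumerate(zip(actual, expected)):
        if axis == time_axis:
            assert got in (want, None)  # allow the unknown time length
        else:
            assert got == want


# Possible usage in the test above (time axis sits right after the batch dims):
# assert_shape_with_unknown_time(smoothed_states, (*batch_size, 7, 2), time_axis=len(batch_size))
```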