Skip to content

Commit 250e81a

Browse files
jessegrabowskiJesse Grabowski
andauthored
Prefer mT over T in statespace equations (#595)
* Use `.mT` in `quad_form_sym` * Use `.mT` everywhere * pre-commit --------- Co-authored-by: Jesse Grabowski <jesse.grabowski@readyx.com>
1 parent 33a6c50 commit 250e81a

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
393393
.. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
394394
2nd ed, Oxford University Press, 2012.
395395
"""
396-
a_hat = T.dot(a) + c
396+
a_hat = T @ a + c
397397
P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q)
398398

399399
return a_hat, P_hat
@@ -580,16 +580,16 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
580580
.. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
581581
2nd ed, Oxford University Press, 2012.
582582
"""
583-
y_hat = d + Z.dot(a)
583+
y_hat = d + Z @ a
584584
v = y - y_hat
585585

586-
PZT = P.dot(Z.T)
586+
PZT = P.dot(Z.mT)
587587
F = Z.dot(PZT) + stabilize(H, self.cov_jitter)
588588

589-
K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T
589+
K = pt.linalg.solve(F.mT, PZT.mT, assume_a="pos", check_finite=False).mT
590590
I_KZ = pt.eye(self.n_states) - K.dot(Z)
591591

592-
a_filtered = a + K.dot(v)
592+
a_filtered = a + K @ v
593593
P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H)
594594

595595
F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False)
@@ -630,9 +630,9 @@ def predict(self, a, P, c, T, R, Q):
630630
a_hat = T.dot(a) + c
631631
Q_chol = pt.linalg.cholesky(Q, lower=True)
632632

633-
M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T
633+
M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).mT
634634
R_decomp = pt.linalg.qr(M, mode="r")
635-
P_chol_hat = R_decomp[: self.n_states, : self.n_states].T
635+
P_chol_hat = R_decomp[..., : self.n_states, : self.n_states].mT
636636

637637
return a_hat, P_chol_hat
638638

@@ -665,7 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
665665
upper = pt.horizontal_stack(H_chol, Z @ P_chol)
666666
lower = pt.horizontal_stack(zeros, P_chol)
667667
A_T = pt.vertical_stack(upper, lower)
668-
B = pt.linalg.qr(A_T.T, mode="r").T
668+
B = pt.linalg.qr(A_T.mT, mode="r").mT
669669

670670
F_chol = B[: self.n_endog, : self.n_endog]
671671
K_F_chol = B[self.n_endog :, : self.n_endog]
@@ -677,6 +677,7 @@ def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
677677
inner_term = solve_triangular(
678678
F_chol, solve_triangular(F_chol, v, lower=True), lower=True
679679
)
680+
680681
loss = (v.T @ inner_term).ravel()
681682

682683
# abs necessary because we're not guaranteed a positive diagonal from the schur decomposition
@@ -800,7 +801,7 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
800801
obs_cov[-1],
801802
)
802803

803-
P_filtered = stabilize(0.5 * (P_filtered + P_filtered.T), self.cov_jitter)
804+
P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter)
804805
a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q)
805806

806807
ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum())

pymc_extras/statespace/filters/utilities.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import pytensor.tensor as pt
22

3-
from pytensor.tensor.nlinalg import matrix_dot
4-
53
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED
64

75

@@ -48,12 +46,11 @@ def split_vars_into_seq_and_nonseq(params, param_names):
4846

4947

5048
def stabilize(cov, jitter=JITTER_DEFAULT):
51-
# Ensure diagonal is non-zero
5249
cov = cov + pt.identity_like(cov) * jitter
5350

5451
return cov
5552

5653

5754
def quad_form_sym(A, B):
58-
out = matrix_dot(A, B, A.T)
59-
return 0.5 * (out + out.T)
55+
out = A @ B @ A.mT
56+
return 0.5 * (out + out.mT)

0 commit comments

Comments
 (0)