@@ -393,7 +393,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]:
393393 .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
394394 2nd ed, Oxford University Press, 2012.
395395 """
396- a_hat = T . dot ( a ) + c
396+ a_hat = T @ a + c
397397 P_hat = quad_form_sym (T , P ) + quad_form_sym (R , Q )
398398
399399 return a_hat , P_hat
@@ -580,16 +580,16 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
580580 .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods.
581581 2nd ed, Oxford University Press, 2012.
582582 """
583- y_hat = d + Z . dot ( a )
583+ y_hat = d + Z @ a
584584 v = y - y_hat
585585
586- PZT = P .dot (Z .T )
586+ PZT = P .dot (Z .mT )
587587 F = Z .dot (PZT ) + stabilize (H , self .cov_jitter )
588588
589- K = pt .linalg .solve (F .T , PZT .T , assume_a = "pos" , check_finite = False ).T
589+ K = pt .linalg .solve (F .mT , PZT .mT , assume_a = "pos" , check_finite = False ).mT
590590 I_KZ = pt .eye (self .n_states ) - K .dot (Z )
591591
592- a_filtered = a + K . dot ( v )
592+ a_filtered = a + K @ v
593593 P_filtered = quad_form_sym (I_KZ , P ) + quad_form_sym (K , H )
594594
595595 F_inv_v = pt .linalg .solve (F , v , assume_a = "pos" , check_finite = False )
@@ -630,9 +630,9 @@ def predict(self, a, P, c, T, R, Q):
630630 a_hat = T .dot (a ) + c
631631 Q_chol = pt .linalg .cholesky (Q , lower = True )
632632
633- M = pt .horizontal_stack (T @ P_chol , R @ Q_chol ).T
633+ M = pt .horizontal_stack (T @ P_chol , R @ Q_chol ).mT
634634 R_decomp = pt .linalg .qr (M , mode = "r" )
635- P_chol_hat = R_decomp [: self .n_states , : self .n_states ].T
635+ P_chol_hat = R_decomp [..., : self .n_states , : self .n_states ].mT
636636
637637 return a_hat , P_chol_hat
638638
@@ -665,7 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
665665 upper = pt .horizontal_stack (H_chol , Z @ P_chol )
666666 lower = pt .horizontal_stack (zeros , P_chol )
667667 A_T = pt .vertical_stack (upper , lower )
668- B = pt .linalg .qr (A_T .T , mode = "r" ).T
668+ B = pt .linalg .qr (A_T .mT , mode = "r" ).mT
669669
670670 F_chol = B [: self .n_endog , : self .n_endog ]
671671 K_F_chol = B [self .n_endog :, : self .n_endog ]
@@ -677,6 +677,7 @@ def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v):
677677 inner_term = solve_triangular (
678678 F_chol , solve_triangular (F_chol , v , lower = True ), lower = True
679679 )
680+
680681 loss = (v .T @ inner_term ).ravel ()
681682
682683 # abs necessary because we're not guaranteed a positive diagonal from the schur decomposition
@@ -800,7 +801,7 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q):
800801 obs_cov [- 1 ],
801802 )
802803
803- P_filtered = stabilize (0.5 * (P_filtered + P_filtered .T ), self .cov_jitter )
804+ P_filtered = stabilize (0.5 * (P_filtered + P_filtered .mT ), self .cov_jitter )
804805 a_hat , P_hat = self .predict (a = a_filtered , P = P_filtered , c = c , T = T , R = R , Q = Q )
805806
806807 ll = - 0.5 * ((pt .neq (ll_inner , 0 ).sum ()) * MVN_CONST + ll_inner .sum ())
0 commit comments