22import typing
33import warnings
44from functools import reduce
5- from typing import TYPE_CHECKING , Literal , cast
5+ from typing import Literal , cast
66
77import numpy as np
88import scipy .linalg
1111import pytensor .tensor as pt
1212from pytensor .graph .basic import Apply
1313from pytensor .graph .op import Op
14- from pytensor .tensor import as_tensor_variable
14+ from pytensor .tensor import TensorLike , as_tensor_variable
1515from pytensor .tensor import basic as ptb
1616from pytensor .tensor import math as ptm
1717from pytensor .tensor .blockwise import Blockwise
2121from pytensor .tensor .variable import TensorVariable
2222
2323
24- if TYPE_CHECKING :
25- from pytensor .tensor import TensorLike
26-
2724logger = logging .getLogger (__name__ )
2825
2926
@@ -777,7 +774,16 @@ def perform(self, node, inputs, outputs):
777774
778775
779776class SolveContinuousLyapunov (Op ):
777+ """
778+ Solves a continuous Lyapunov equation, :math:`AX + XA^H = B`, for :math:`X.
779+
780+ Continuous time Lyapunov equations are special cases of Sylvester equations, :math:`AX + XB = C`, and can be solved
781+ efficiently using the Bartels-Stewart algorithm. For more details, see the docstring for
782+ scipy.linalg.solve_continuous_lyapunov
783+ """
784+
780785 __props__ = ()
786+ gufunc_signature = "(m,m),(m,m)->(m,m)"
781787
782788 def make_node (self , A , B ):
783789 A = as_tensor_variable (A )
@@ -792,7 +798,8 @@ def perform(self, node, inputs, output_storage):
792798 (A , B ) = inputs
793799 X = output_storage [0 ]
794800
795- X [0 ] = scipy .linalg .solve_continuous_lyapunov (A , B )
801+ out_dtype = node .outputs [0 ].type .dtype
802+ X [0 ] = scipy .linalg .solve_continuous_lyapunov (A , B ).astype (out_dtype )
796803
797804 def infer_shape (self , fgraph , node , shapes ):
798805 return [shapes [0 ]]
@@ -813,7 +820,41 @@ def grad(self, inputs, output_grads):
813820 return [A_bar , Q_bar ]
814821
815822
823+ _solve_continuous_lyapunov = Blockwise (SolveContinuousLyapunov ())
824+
825+
826+ def solve_continuous_lyapunov (A : TensorLike , Q : TensorLike ) -> TensorVariable :
827+ """
828+ Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
829+
830+ Parameters
831+ ----------
832+ A: TensorLike
833+ Square matrix of shape ``N x N``.
834+ Q: TensorLike
835+ Square matrix of shape ``N x N``.
836+
837+ Returns
838+ -------
839+ X: TensorVariable
840+ Square matrix of shape ``N x N``
841+
842+ """
843+
844+ return cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
845+
846+
816847class BilinearSolveDiscreteLyapunov (Op ):
848+ """
849+ Solves a discrete lyapunov equation, :math:`AXA^H - X = Q`, for :math:`X.
850+
851+ The solution is computed by first transforming the discrete-time problem into a continuous-time form. The continuous
852+ time lyapunov is a special case of a Sylvester equation, and can be efficiently solved. For more details, see the
853+ docstring for scipy.linalg.solve_discrete_lyapunov
854+ """
855+
856+ gufunc_signature = "(m,m),(m,m)->(m,m)"
857+
817858 def make_node (self , A , B ):
818859 A = as_tensor_variable (A )
819860 B = as_tensor_variable (B )
@@ -827,7 +868,10 @@ def perform(self, node, inputs, output_storage):
827868 (A , B ) = inputs
828869 X = output_storage [0 ]
829870
830- X [0 ] = scipy .linalg .solve_discrete_lyapunov (A , B , method = "bilinear" )
871+ out_dtype = node .outputs [0 ].type .dtype
872+ X [0 ] = scipy .linalg .solve_discrete_lyapunov (A , B , method = "bilinear" ).astype (
873+ out_dtype
874+ )
831875
832876 def infer_shape (self , fgraph , node , shapes ):
833877 return [shapes [0 ]]
@@ -849,83 +893,83 @@ def grad(self, inputs, output_grads):
849893 return [A_bar , Q_bar ]
850894
851895
852- _solve_continuous_lyapunov = SolveContinuousLyapunov ()
853- _solve_bilinear_direct_lyapunov = cast (typing .Callable , BilinearSolveDiscreteLyapunov ())
896+ _bilinear_solve_discrete_lyapunov = Blockwise (BilinearSolveDiscreteLyapunov ())
854897
855898
856- def _direct_solve_discrete_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
857- A_ = as_tensor_variable (A )
858- Q_ = as_tensor_variable (Q )
899+ def _direct_solve_discrete_lyapunov (
900+ A : TensorVariable , Q : TensorVariable
901+ ) -> TensorVariable :
902+ r"""
903+ Directly solve the discrete Lyapunov equation :math:`A X A^H - X = Q` using the kronecker method of Magnus and
904+ Neudecker.
905+
906+ This involves constructing and inverting an intermediate matrix :math:`A \otimes A`, with shape :math:`N^2 x N^2`.
907+ As a result, this method scales poorly with the size of :math:`N`, and should be avoided for large :math:`N`.
908+ """
859909
860- if "complex" in A_ .type .dtype :
861- AA = kron (A_ , A_ .conj ())
910+ if A .type .dtype . startswith ( "complex" ) :
911+ AxA = kron (A , A .conj ())
862912 else :
863- AA = kron (A_ , A_ )
913+ AxA = kron (A , A )
914+
915+ eye = pt .eye (AxA .shape [- 1 ])
864916
865- X = solve (pt .eye (AA .shape [0 ]) - AA , Q_ .ravel ())
866- return cast (TensorVariable , reshape (X , Q_ .shape ))
917+ vec_Q = Q .ravel ()
918+ vec_X = solve (eye - AxA , vec_Q , b_ndim = 1 )
919+
920+ return cast (TensorVariable , reshape (vec_X , A .shape ))
867921
868922
869923def solve_discrete_lyapunov (
870- A : "TensorLike" , Q : "TensorLike" , method : Literal ["direct" , "bilinear" ] = "direct"
924+ A : TensorLike ,
925+ Q : TensorLike ,
926+ method : Literal ["direct" , "bilinear" ] = "bilinear" ,
871927) -> TensorVariable :
872928 """Solve the discrete Lyapunov equation :math:`A X A^H - X = Q`.
873929
874930 Parameters
875931 ----------
876- A
877- Square matrix of shape N x N; must have the same shape as Q
878- Q
879- Square matrix of shape N x N; must have the same shape as A
880- method
881- Solver method used, one of ``"direct"`` or ``"bilinear"``. ``"direct"``
882- solves the problem directly via matrix inversion. This has a pure
883- PyTensor implementation and can thus be cross-compiled to supported
884- backends, and should be preferred when ``N`` is not large. The direct
885- method scales poorly with the size of ``N``, and the bilinear can be
932+ A: TensorLike
933+ Square matrix of shape N x N
934+ Q: TensorLike
935+ Square matrix of shape N x N
936+ method: str, one of ``"direct"`` or ``"bilinear"``
937+ Solver method used, . ``"direct"`` solves the problem directly via matrix inversion. This has a pure
938+ PyTensor implementation and can thus be cross-compiled to supported backends, and should be preferred when
939+ ``N`` is not large. The direct method scales poorly with the size of ``N``, and the bilinear can be
886940 used in these cases.
887941
888942 Returns
889943 -------
890- Square matrix of shape ``N x N``, representing the solution to the
891- Lyapunov equation
944+ X: TensorVariable
945+ Square matrix of shape ``N x N``. Solution to the Lyapunov equation
892946
893947 """
894948 if method not in ["direct" , "bilinear" ]:
895949 raise ValueError (
896950 f'Parameter "method" must be one of "direct" or "bilinear", found { method } '
897951 )
898952
899- if method == "direct" :
900- return _direct_solve_discrete_lyapunov (A , Q )
901- if method == "bilinear" :
902- return cast (TensorVariable , _solve_bilinear_direct_lyapunov (A , Q ))
903-
904-
905- def solve_continuous_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
906- """Solve the continuous Lyapunov equation :math:`A X + X A^H + Q = 0`.
907-
908- Parameters
909- ----------
910- A
911- Square matrix of shape ``N x N``; must have the same shape as `Q`.
912- Q
913- Square matrix of shape ``N x N``; must have the same shape as `A`.
953+ A = as_tensor_variable (A )
954+ Q = as_tensor_variable (Q )
914955
915- Returns
916- -------
917- Square matrix of shape ``N x N``, representing the solution to the
918- Lyapunov equation
956+ if method == "direct" :
957+ signature = BilinearSolveDiscreteLyapunov . gufunc_signature
958+ X = pt . vectorize ( _direct_solve_discrete_lyapunov , signature = signature )( A , Q )
959+ return cast ( TensorVariable , X )
919960
920- """
961+ elif method == "bilinear" :
962+ return cast (TensorVariable , _bilinear_solve_discrete_lyapunov (A , Q ))
921963
922- return cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
964+ else :
965+ raise ValueError (f"Unknown method { method } " )
923966
924967
925- class SolveDiscreteARE (pt . Op ):
968+ class SolveDiscreteARE (Op ):
926969 __props__ = ("enforce_Q_symmetric" ,)
970+ gufunc_signature = "(m,m),(m,n),(m,m),(n,n)->(m,m)"
927971
928- def __init__ (self , enforce_Q_symmetric = False ):
972+ def __init__ (self , enforce_Q_symmetric : bool = False ):
929973 self .enforce_Q_symmetric = enforce_Q_symmetric
930974
931975 def make_node (self , A , B , Q , R ):
@@ -946,9 +990,8 @@ def perform(self, node, inputs, output_storage):
946990 if self .enforce_Q_symmetric :
947991 Q = 0.5 * (Q + Q .T )
948992
949- X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R ).astype (
950- node .outputs [0 ].type .dtype
951- )
993+ out_dtype = node .outputs [0 ].type .dtype
994+ X [0 ] = scipy .linalg .solve_discrete_are (A , B , Q , R ).astype (out_dtype )
952995
953996 def infer_shape (self , fgraph , node , shapes ):
954997 return [shapes [0 ]]
@@ -960,14 +1003,16 @@ def grad(self, inputs, output_grads):
9601003 (dX ,) = output_grads
9611004 X = self (A , B , Q , R )
9621005
963- K_inner = R + pt .linalg .matrix_dot (B .T , X , B )
964- K_inner_inv = pt .linalg .solve (K_inner , pt .eye (R .shape [0 ]))
965- K = matrix_dot (K_inner_inv , B .T , X , A )
1006+ K_inner = R + matrix_dot (B .T , X , B )
1007+
1008+ # K_inner is guaranteed to be symmetric, because X and R are symmetric
1009+ K_inner_inv_BT = solve (K_inner , B .T , assume_a = "sym" )
1010+ K = matrix_dot (K_inner_inv_BT , X , A )
9661011
9671012 A_tilde = A - B .dot (K )
9681013
9691014 dX_symm = 0.5 * (dX + dX .T )
970- S = solve_discrete_lyapunov (A_tilde , dX_symm ). astype ( dX . type . dtype )
1015+ S = solve_discrete_lyapunov (A_tilde , dX_symm )
9711016
9721017 A_bar = 2 * matrix_dot (X , A_tilde , S )
9731018 B_bar = - 2 * matrix_dot (X , A_tilde , S , K .T )
@@ -977,30 +1022,45 @@ def grad(self, inputs, output_grads):
9771022 return [A_bar , B_bar , Q_bar , R_bar ]
9781023
9791024
980- def solve_discrete_are (A , B , Q , R , enforce_Q_symmetric = False ) -> TensorVariable :
1025+ def solve_discrete_are (
1026+ A : TensorLike ,
1027+ B : TensorLike ,
1028+ Q : TensorLike ,
1029+ R : TensorLike ,
1030+ enforce_Q_symmetric : bool = False ,
1031+ ) -> TensorVariable :
9811032 """
9821033 Solve the discrete Algebraic Riccati equation :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.
9831034
1035+ Discrete-time Algebraic Riccati equations arise in the context of optimal control and filtering problems, as the
1036+ solution to Linear-Quadratic Regulators (LQR), Linear-Quadratic-Guassian (LQG) control problems, and as the
1037+ steady-state covariance of the Kalman Filter.
1038+
1039+ Such problems typically have many solutions, but we are generally only interested in the unique *stabilizing*
1040+ solution. This stable solution, if it exists, will be returned by this function.
1041+
9841042 Parameters
9851043 ----------
986- A: ArrayLike
1044+ A: TensorLike
9871045 Square matrix of shape M x M
988- B: ArrayLike
1046+ B: TensorLike
9891047 Square matrix of shape M x M
990- Q: ArrayLike
1048+ Q: TensorLike
9911049 Symmetric square matrix of shape M x M
992- R: ArrayLike
1050+ R: TensorLike
9931051 Square matrix of shape N x N
9941052 enforce_Q_symmetric: bool
9951053 If True, the provided Q matrix is transformed to 0.5 * (Q + Q.T) to ensure symmetry
9961054
9971055 Returns
9981056 -------
999- X: pt.matrix
1057+ X: TensorVariable
10001058 Square matrix of shape M x M, representing the solution to the DARE
10011059 """
10021060
1003- return cast (TensorVariable , SolveDiscreteARE (enforce_Q_symmetric )(A , B , Q , R ))
1061+ return cast (
1062+ TensorVariable , Blockwise (SolveDiscreteARE (enforce_Q_symmetric ))(A , B , Q , R )
1063+ )
10041064
10051065
10061066def _largest_common_dtype (tensors : typing .Sequence [TensorVariable ]) -> np .dtype :
0 commit comments