3636import numpy as np
3737from numpy .testing import assert_equal , assert_almost_equal
3838
39+ from scipy .linalg import sqrtm
40+
3941from qinfer .tests .base_test import DerandomizedTestCase , MockModel , assert_warns
4042
41- from qinfer .utils import in_ellipsoid , assert_sigfigs_equal
43+ from qinfer .utils import in_ellipsoid , assert_sigfigs_equal , sqrtm_psd
4244
4345## TESTS #####################################################################
4446
@@ -80,7 +82,7 @@ def test_assert_sigfigs_equal(self):
8082 np .array ([1728 ]),
8183 4
8284 )
83-
85+
8486
8587class TestEllipsoids (DerandomizedTestCase ):
8688
@@ -116,4 +118,26 @@ def test_in_ellipsoid(self):
116118 assert_equal (
117119 in_ellipsoid (x , A , c ),
118120 np .array ([1 ,0 ,1 ,0 ], dtype = bool )
119- )
121+ )
122+
123+ class TestLinearAlgebra (DerandomizedTestCase ):
124+ def test_sqrtm_psd (self ):
125+ # Construct Y = XX^T as a PSD matrix.
126+ X = np .random .random ((5 , 5 ))
127+ Y = np .dot (X , X .T )
128+ sqrt_Y = sqrtm_psd (Y , est_error = False )
129+
130+ np .testing .assert_allclose (
131+ np .dot (sqrt_Y , sqrt_Y ),
132+ Y
133+ )
134+
135+ # Try again, but with a singular matrix.
136+ Y_singular = np .zeros ((6 , 6 ))
137+ Y_singular [:5 , :5 ] = Y
138+ sqrt_Y_singular = sqrtm_psd (Y_singular , est_error = False )
139+
140+ np .testing .assert_allclose (
141+ np .dot (sqrt_Y_singular , sqrt_Y_singular ),
142+ Y_singular
143+ )
0 commit comments