Skip to content

Commit ad5be44

Browse files
authored
Merge pull request #110 from QInfer/fix-lw-cov-esti
Fix for #108.
2 parents 12b2612 + 560d5ff commit ad5be44

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

src/qinfer/resamplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import scipy.linalg as la
4343
import warnings
4444

45-
from .utils import outer_product, particle_meanfn, particle_covariance_mtx
45+
from .utils import outer_product, particle_meanfn, particle_covariance_mtx, sqrtm_psd
4646

4747
from abc import ABCMeta, abstractmethod, abstractproperty
4848
from future.utils import with_metaclass
@@ -260,7 +260,7 @@ def __call__(self, model, particle_weights, particle_locations,
260260
ResamplerWarning
261261
)
262262
cov = self._zero_cov_comp * np.eye(cov.shape[0])
263-
S, S_err = la.sqrtm(cov, disp=False)
263+
S, S_err = sqrtm_psd(cov)
264264
if not np.isfinite(S_err):
265265
raise ResamplerError(
266266
"Infinite error in computing the square root of the "

src/qinfer/tests/test_utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
3636
import numpy as np
3737
from numpy.testing import assert_equal, assert_almost_equal
3838

39+
from scipy.linalg import sqrtm
40+
3941
from qinfer.tests.base_test import DerandomizedTestCase, MockModel, assert_warns
4042

41-
from qinfer.utils import in_ellipsoid, assert_sigfigs_equal
43+
from qinfer.utils import in_ellipsoid, assert_sigfigs_equal, sqrtm_psd
4244

4345
## TESTS #####################################################################
4446

@@ -80,7 +82,7 @@ def test_assert_sigfigs_equal(self):
8082
np.array([1728]),
8183
4
8284
)
83-
85+
8486

8587
class TestEllipsoids(DerandomizedTestCase):
8688

@@ -116,4 +118,26 @@ def test_in_ellipsoid(self):
116118
assert_equal(
117119
in_ellipsoid(x, A, c),
118120
np.array([1,0,1,0], dtype=bool)
119-
)
121+
)
122+
123+
class TestLinearAlgebra(DerandomizedTestCase):
124+
def test_sqrtm_psd(self):
125+
# Construct Y = XX^T as a PSD matrix.
126+
X = np.random.random((5, 5))
127+
Y = np.dot(X, X.T)
128+
sqrt_Y = sqrtm_psd(Y, est_error=False)
129+
130+
np.testing.assert_allclose(
131+
np.dot(sqrt_Y, sqrt_Y),
132+
Y
133+
)
134+
135+
# Try again, but with a singular matrix.
136+
Y_singular = np.zeros((6, 6))
137+
Y_singular[:5, :5] = Y
138+
sqrt_Y_singular = sqrtm_psd(Y_singular, est_error=False)
139+
140+
np.testing.assert_allclose(
141+
np.dot(sqrt_Y_singular, sqrt_Y_singular),
142+
Y_singular
143+
)

src/qinfer/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
import numpy as np
3939
import numpy.linalg as la
4040

41+
from scipy.linalg import eigh
42+
4143
from scipy.stats import logistic, binom
4244
from scipy.special import gammaln, gamma
4345
from scipy.linalg import sqrtm
@@ -455,6 +457,21 @@ def safe_shape(arr, idx=0, default=1):
455457
shape = np.shape(arr)
456458
return shape[idx] if idx < len(shape) else default
457459

460+
def sqrtm_psd(A, est_error=True, check_finite=True):
461+
"""
462+
Returns the matrix square root of a positive semidefinite matrix,
463+
truncating negative eigenvalues.
464+
"""
465+
w, v = eigh(A, check_finite=check_finite)
466+
mask = w <= 0
467+
w[mask] = 0
468+
np.sqrt(w, out=w)
469+
A_sqrt = (v * w).dot(v.conj().T)
470+
471+
if est_error:
472+
return A_sqrt, np.linalg.norm(np.dot(A_sqrt, A_sqrt) - A, 'fro')
473+
else:
474+
return A_sqrt
458475

459476
#==============================================================================
460477
#Test Code

0 commit comments

Comments
 (0)