Skip to content

Commit 617adaa

Browse files
wdevazelhesWilliam de Vazelhes 80055062
and
William de Vazelhes 80055062
authored
[MRG] Fix test for components_from_metric and add tests for _check_sdp_from_eigen (#303)
* Fix test for components_from_metric and add tests for _check_sdp_from_eigen * Fix trailing whitespace * Remove unused LinAlgError Co-authored-by: William de Vazelhes 80055062 <[email protected]>
1 parent 86a5208 commit 617adaa

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

test/test_components_metric_conversion.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import unittest
22
import numpy as np
33
import pytest
4-
from numpy.linalg import LinAlgError
54
from scipy.stats import ortho_group
65
from sklearn.datasets import load_iris
76
from numpy.testing import assert_array_almost_equal, assert_allclose
@@ -117,17 +116,14 @@ def test_components_from_metric_edge_cases(self):
117116
L = components_from_metric(M)
118117
assert_allclose(L.T.dot(L), M)
119118

120-
# matrix with a determinant still high but which should be considered as a
121-
# non-definite matrix (to check we don't test the definiteness with the
122-
# determinant which is a bad strategy)
119+
# matrix with a determinant still high but which is
120+
# undefinite w.r.t to numpy standards
123121
M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20])
124122
M = P.dot(M).dot(P.T)
125123
assert np.abs(np.linalg.det(M)) > 10
126124
assert np.linalg.slogdet(M)[1] > 1 # (just to show that the computed
127125
# determinant is far from null)
128-
with pytest.raises(LinAlgError) as err_msg:
129-
np.linalg.cholesky(M)
130-
assert str(err_msg.value) == 'Matrix is not positive definite'
126+
assert np.linalg.matrix_rank(M) < M.shape[0]
131127
# (just to show that this case is indeed considered by numpy as an
132128
# indefinite case)
133129
L = components_from_metric(M)

test/test_utils.py

+47
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,53 @@ def test__check_sdp_from_eigen_returns_definiteness(w, is_definite):
10551055
assert _check_sdp_from_eigen(w) == is_definite
10561056

10571057

1058+
@pytest.mark.unit
1059+
@pytest.mark.parametrize('w, tol, is_definite',
1060+
[(np.array([5., 3.]), 2, True),
1061+
(np.array([5., 1.]), 2, False),
1062+
(np.array([5., -1.]), 2, False)])
1063+
def test__check_sdp_from_eigen_tol_psd(w, tol, is_definite):
1064+
"""Tests that _check_sdp_from_eigen, for PSD matrices, returns
1065+
False if an eigenvalue is lower than tol"""
1066+
assert _check_sdp_from_eigen(w, tol=tol) == is_definite
1067+
1068+
1069+
@pytest.mark.unit
1070+
@pytest.mark.parametrize('w, tol',
1071+
[(np.array([5., -3.]), 2),
1072+
(np.array([1., -3.]), 2)])
1073+
def test__check_sdp_from_eigen_tol_non_psd(w, tol):
1074+
"""Tests that _check_sdp_from_eigen raises a NonPSDError
1075+
when there is a negative value with abs value higher than tol"""
1076+
with pytest.raises(NonPSDError):
1077+
_check_sdp_from_eigen(w, tol=tol)
1078+
1079+
1080+
@pytest.mark.unit
1081+
@pytest.mark.parametrize('w, is_definite',
1082+
[(np.array([1e5, 1e5, 1e5, 1e5,
1083+
1e5, 1e5, 1e-20]), False),
1084+
(np.array([1e-10, 1e-10]), True)])
1085+
def test__check_sdp_from_eigen_tol_default_psd(w, is_definite):
1086+
"""Tests that the default tol argument gives good results for edge cases
1087+
like even if the determinant is high but clearly one eigenvalue is low,
1088+
(undefinite so returns False) or when all eigenvalues are low (definite so
1089+
returns True)"""
1090+
assert _check_sdp_from_eigen(w, tol=None) == is_definite
1091+
1092+
1093+
@pytest.mark.unit
1094+
@pytest.mark.parametrize('w',
1095+
[np.array([1., -1.]),
1096+
np.array([-1e-10, 1e-10])])
1097+
def test__check_sdp_from_eigen_tol_default_non_psd(w):
1098+
"""Tests that the default tol argument is good for raising
1099+
NonPSDError, e.g. that when a value is clearly relatively
1100+
negative it raises such an error"""
1101+
with pytest.raises(NonPSDError):
1102+
_check_sdp_from_eigen(w, tol=None)
1103+
1104+
10581105
def test__check_n_components():
10591106
"""Checks that n_components returns what is expected
10601107
(including the errors)"""

0 commit comments

Comments
 (0)