Commit 7eef7c6

Authored by William de Vazelhes (wdevazelhes)
[MRG+2] Update repo to work with both new and old scikit-learn (#313)
* Update repo to work with both new and old scikit-learn, and add a travis job to test the old scikit-learn
* add random seed for test
* fix str representation for various sklearn versions
* fix flake8 error
* add more samples to the test for robustness
* add comment for additional travis test
* change str to dict for simplicity
* simplify conditional imports with sklearn_shims file
* remove typo, add blank line
* empty commit

Co-authored-by: William de Vazelhes 80055062 <[email protected]>
1 parent 66a12ed commit 7eef7c6

11 files changed (+170, -42 lines)

.travis.yml (+13)

@@ -39,6 +39,19 @@ matrix:
         - pytest test --cov;
       after_success:
         - bash <(curl -s https://codecov.io/bash)
+    - name: "Pytest python 3.6 with skggm + scikit-learn 0.20.3"
+      # checks that tests work for the oldest supported scikit-learn version
+      python: "3.6"
+      before_install:
+        - sudo apt-get install liblapack-dev
+        - pip install --upgrade pip pytest
+        - pip install wheel cython numpy scipy codecov pytest-cov
+        - pip install scikit-learn==0.20.3
+        - pip install git+https://github.com/skggm/skggm.git@${SKGGM_VERSION};
+      script:
+        - pytest test --cov;
+      after_success:
+        - bash <(curl -s https://codecov.io/bash)
     - name: "Syntax checking with flake8"
       python: "3.7"
       before_install:

metric_learn/sklearn_shims.py (+27)

@@ -0,0 +1,27 @@
+"""This file is for fixing imports due to different APIs
+depending on the scikit-learn version"""
+import sklearn
+from packaging import version
+SKLEARN_AT_LEAST_0_22 = (version.parse(sklearn.__version__)
+                         >= version.parse('0.22.0'))
+if SKLEARN_AT_LEAST_0_22:
+  from sklearn.utils._testing import (set_random_state,
+                                      assert_warns_message,
+                                      ignore_warnings,
+                                      assert_allclose_dense_sparse,
+                                      _get_args)
+  from sklearn.utils.estimator_checks import (_is_public_parameter
+                                              as is_public_parameter)
+  from sklearn.metrics._scorer import get_scorer
+else:
+  from sklearn.utils.testing import (set_random_state,
+                                     assert_warns_message,
+                                     ignore_warnings,
+                                     assert_allclose_dense_sparse,
+                                     _get_args)
+  from sklearn.utils.estimator_checks import is_public_parameter
+  from sklearn.metrics.scorer import get_scorer
+
+__all__ = ['set_random_state', 'assert_warns_message', 'set_random_state',
+           'ignore_warnings', 'assert_allclose_dense_sparse', '_get_args',
+           'is_public_parameter', 'get_scorer']
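
Every other file in this commit essentially swaps a "from sklearn.utils.testing import ..." line for an import from this shim, so the helper names and call signatures stay identical on both scikit-learn branches. A rough, illustrative sketch of the resulting usage (not part of the diff):

# Illustrative only, not from the commit: the shim resolves the helpers from
# whichever scikit-learn module exists in the installed version.
from metric_learn.sklearn_shims import set_random_state, get_scorer
from metric_learn import NCA

nca = NCA(max_iter=10)
set_random_state(nca, random_state=42)  # same call on 0.20.3 and on >= 0.22
scorer = get_scorer('accuracy')         # backed by metrics._scorer or metrics.scorer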

test/metric_learn_test.py (+1, -1)

@@ -9,7 +9,7 @@
                             make_spd_matrix)
 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
                            assert_allclose)
-from sklearn.utils.testing import assert_warns_message
+from metric_learn.sklearn_shims import assert_warns_message
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.utils.validation import check_X_y
 from sklearn.preprocessing import StandardScaler

test/test_base_metric.py (+106, -16)

@@ -4,71 +4,161 @@
 import metric_learn
 import numpy as np
 from sklearn import clone
-from sklearn.utils.testing import set_random_state
 from test.test_utils import ids_metric_learners, metric_learners, remove_y
+from metric_learn.sklearn_shims import set_random_state, SKLEARN_AT_LEAST_0_22
 
 
 def remove_spaces(s):
   return re.sub(r'\s+', '', s)
 
 
+def sk_repr_kwargs(def_kwargs, nndef_kwargs):
+  """Given the non-default arguments, and the default
+  keywords arguments, build the string that will appear
+  in the __repr__ of the estimator, depending on the
+  version of scikit-learn.
+  """
+  if SKLEARN_AT_LEAST_0_22:
+    def_kwargs = {}
+  def_kwargs.update(nndef_kwargs)
+  args_str = ",".join(f"{key}={repr(value)}"
+                      for key, value in def_kwargs.items())
+  return args_str
+
+
 class TestStringRepr(unittest.TestCase):
 
   def test_covariance(self):
+    def_kwargs = {'preprocessor': None}
+    nndef_kwargs = {}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.Covariance())),
-                     remove_spaces("Covariance()"))
+                     remove_spaces(f"Covariance({merged_kwargs})"))
 
   def test_lmnn(self):
+    def_kwargs = {'convergence_tol': 0.001, 'init': 'auto', 'k': 3,
+                  'learn_rate': 1e-07, 'max_iter': 1000, 'min_iter': 50,
+                  'n_components': None, 'preprocessor': None,
+                  'random_state': None, 'regularization': 0.5,
+                  'verbose': False}
+    nndef_kwargs = {'convergence_tol': 0.01, 'k': 6}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
         remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, k=6))),
-        remove_spaces("LMNN(convergence_tol=0.01, k=6)"))
+        remove_spaces(f"LMNN({merged_kwargs})"))
 
   def test_nca(self):
+    def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None,
+                  'preprocessor': None, 'random_state': None, 'tol': None,
+                  'verbose': False}
+    nndef_kwargs = {'max_iter': 42}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.NCA(max_iter=42))),
-                     remove_spaces("NCA(max_iter=42)"))
+                     remove_spaces(f"NCA({merged_kwargs})"))
 
   def test_lfda(self):
+    def_kwargs = {'embedding_type': 'weighted', 'k': None,
+                  'n_components': None, 'preprocessor': None}
+    nndef_kwargs = {'k': 2}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.LFDA(k=2))),
-                     remove_spaces("LFDA(k=2)"))
+                     remove_spaces(f"LFDA({merged_kwargs})"))
 
   def test_itml(self):
+    def_kwargs = {'convergence_threshold': 0.001, 'gamma': 1.0,
+                  'max_iter': 1000, 'preprocessor': None,
+                  'prior': 'identity', 'random_state': None, 'verbose': False}
+    nndef_kwargs = {'gamma': 0.5}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.ITML(gamma=0.5))),
-                     remove_spaces("ITML(gamma=0.5)"))
+                     remove_spaces(f"ITML({merged_kwargs})"))
+    def_kwargs = {'convergence_threshold': 0.001, 'gamma': 1.0,
+                  'max_iter': 1000, 'num_constraints': None,
+                  'preprocessor': None, 'prior': 'identity',
+                  'random_state': None, 'verbose': False}
+    nndef_kwargs = {'num_constraints': 7}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
         remove_spaces(str(metric_learn.ITML_Supervised(num_constraints=7))),
-        remove_spaces("ITML_Supervised(num_constraints=7)"))
+        remove_spaces(f"ITML_Supervised({merged_kwargs})"))
 
   def test_lsml(self):
+    def_kwargs = {'max_iter': 1000, 'preprocessor': None, 'prior': 'identity',
+                  'random_state': None, 'tol': 0.001, 'verbose': False}
+    nndef_kwargs = {'tol': 0.1}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.LSML(tol=0.1))),
-                     remove_spaces("LSML(tol=0.1)"))
+                     remove_spaces(f"LSML({merged_kwargs})"))
+    def_kwargs = {'max_iter': 1000, 'num_constraints': None,
+                  'preprocessor': None, 'prior': 'identity',
+                  'random_state': None, 'tol': 0.001, 'verbose': False,
+                  'weights': None}
+    nndef_kwargs = {'verbose': True}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
         remove_spaces(str(metric_learn.LSML_Supervised(verbose=True))),
-        remove_spaces("LSML_Supervised(verbose=True)"))
+        remove_spaces(f"LSML_Supervised({merged_kwargs})"))
 
   def test_sdml(self):
+    def_kwargs = {'balance_param': 0.5, 'preprocessor': None,
+                  'prior': 'identity', 'random_state': None,
+                  'sparsity_param': 0.01, 'verbose': False}
+    nndef_kwargs = {'verbose': True}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.SDML(verbose=True))),
-                     remove_spaces("SDML(verbose=True)"))
+                     remove_spaces(f"SDML({merged_kwargs})"))
+    def_kwargs = {'balance_param': 0.5, 'num_constraints': None,
+                  'preprocessor': None, 'prior': 'identity',
+                  'random_state': None, 'sparsity_param': 0.01,
+                  'verbose': False}
+    nndef_kwargs = {'sparsity_param': 0.5}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
         remove_spaces(str(metric_learn.SDML_Supervised(sparsity_param=0.5))),
-        remove_spaces("SDML_Supervised(sparsity_param=0.5)"))
+        remove_spaces(f"SDML_Supervised({merged_kwargs})"))
 
   def test_rca(self):
+    def_kwargs = {'n_components': None, 'preprocessor': None}
+    nndef_kwargs = {'n_components': 3}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.RCA(n_components=3))),
-                     remove_spaces("RCA(n_components=3)"))
+                     remove_spaces(f"RCA({merged_kwargs})"))
+    def_kwargs = {'chunk_size': 2, 'n_components': None, 'num_chunks': 100,
+                  'preprocessor': None, 'random_state': None}
+    nndef_kwargs = {'num_chunks': 5}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
         remove_spaces(str(metric_learn.RCA_Supervised(num_chunks=5))),
-        remove_spaces("RCA_Supervised(num_chunks=5)"))
+        remove_spaces(f"RCA_Supervised({merged_kwargs})"))
 
   def test_mlkr(self):
+    def_kwargs = {'init': 'auto', 'max_iter': 1000,
+                  'n_components': None, 'preprocessor': None,
+                  'random_state': None, 'tol': None, 'verbose': False}
+    nndef_kwargs = {'max_iter': 777}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.MLKR(max_iter=777))),
-                     remove_spaces("MLKR(max_iter=777)"))
+                     remove_spaces(f"MLKR({merged_kwargs})"))
 
   def test_mmc(self):
+    def_kwargs = {'convergence_threshold': 0.001, 'diagonal': False,
+                  'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100,
+                  'max_proj': 10000, 'preprocessor': None,
+                  'random_state': None, 'verbose': False}
+    nndef_kwargs = {'diagonal': True}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(remove_spaces(str(metric_learn.MMC(diagonal=True))),
-                     remove_spaces("MMC(diagonal=True)"))
+                     remove_spaces(f"MMC({merged_kwargs})"))
+    def_kwargs = {'convergence_threshold': 1e-06, 'diagonal': False,
+                  'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100,
+                  'max_proj': 10000, 'num_constraints': None,
+                  'preprocessor': None, 'random_state': None,
+                  'verbose': False}
+    nndef_kwargs = {'max_iter': 1}
+    merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
         remove_spaces(str(metric_learn.MMC_Supervised(max_iter=1))),
-        remove_spaces("MMC_Supervised(max_iter=1)"))
+        remove_spaces(f"MMC_Supervised({merged_kwargs})"))
 
 
 @pytest.mark.parametrize('estimator, build_dataset', metric_learners,
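
To make sk_repr_kwargs concrete, here is an illustrative sketch (not part of the diff) of what it returns for the NCA case above, assuming sk_repr_kwargs and SKLEARN_AT_LEAST_0_22 from this file are in scope. It simply follows the helper's own logic: on the newer branch only the non-default arguments appear in the expected repr, while on the older branch every constructor argument is listed.

# Illustrative only: expected repr fragment for metric_learn.NCA(max_iter=42).
def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None,
              'preprocessor': None, 'random_state': None, 'tol': None,
              'verbose': False}
nndef_kwargs = {'max_iter': 42}

# If SKLEARN_AT_LEAST_0_22 is True, the defaults are dropped:
#   "max_iter=42"
# Otherwise all arguments are kept, with max_iter overridden:
#   "init='auto',max_iter=42,n_components=None,preprocessor=None,
#    random_state=None,tol=None,verbose=False"
print(sk_repr_kwargs(def_kwargs, nndef_kwargs))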

test/test_components_metric_conversion.py (+1, -1)

@@ -4,7 +4,7 @@
 from scipy.stats import ortho_group
 from sklearn.datasets import load_iris
 from numpy.testing import assert_array_almost_equal, assert_allclose
-from sklearn.utils.testing import ignore_warnings
+from metric_learn.sklearn_shims import ignore_warnings
 
 from metric_learn import (
     LMNN, NCA, LFDA, Covariance, MLKR,

test/test_mahalanobis_mixin.py (+1, -1)

@@ -11,7 +11,7 @@
 from sklearn.datasets import make_spd_matrix, make_blobs
 from sklearn.utils import check_random_state, shuffle
 from sklearn.utils.multiclass import type_of_target
-from sklearn.utils.testing import set_random_state
+from metric_learn.sklearn_shims import set_random_state
 
 from metric_learn._util import make_context, _initialize_metric_mahalanobis
 from metric_learn.base_metric import (_QuadrupletsClassifierMixin,

test/test_pairs_classifiers.py (+1, -1)

@@ -11,7 +11,7 @@
 from sklearn.model_selection import train_test_split
 
 from test.test_utils import pairs_learners, ids_pairs_learners
-from sklearn.utils.testing import set_random_state
+from metric_learn.sklearn_shims import set_random_state
 from sklearn import clone
 import numpy as np
 from itertools import product

test/test_quadruplets_classifiers.py (+1, -1)

@@ -3,7 +3,7 @@
 from sklearn.model_selection import train_test_split
 
 from test.test_utils import quadruplets_learners, ids_quadruplets_learners
-from sklearn.utils.testing import set_random_state
+from metric_learn.sklearn_shims import set_random_state
 from sklearn import clone
 import numpy as np
 

test/test_sklearn_compat.py (+17, -19)

@@ -4,10 +4,9 @@
 from sklearn.base import TransformerMixin
 from sklearn.pipeline import make_pipeline
 from sklearn.utils import check_random_state
-from sklearn.utils.estimator_checks import is_public_parameter
-from sklearn.utils.testing import (assert_allclose_dense_sparse,
-                                   set_random_state)
-
+from metric_learn.sklearn_shims import (assert_allclose_dense_sparse,
+                                        set_random_state, _get_args,
+                                        is_public_parameter, get_scorer)
 from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA,
                           ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -16,8 +15,6 @@
 import numpy as np
 from sklearn.model_selection import (cross_val_score, cross_val_predict,
                                      train_test_split, KFold)
-from sklearn.metrics.scorer import get_scorer
-from sklearn.utils.testing import _get_args
 from test.test_utils import (metric_learners, ids_metric_learners,
                              mock_preprocessor, tuples_learners,
                              ids_tuples_learners, pairs_learners,
@@ -52,37 +49,37 @@ def __init__(self, sparsity_param=0.01,
 
 class TestSklearnCompat(unittest.TestCase):
   def test_covariance(self):
-    check_estimator(Covariance)
+    check_estimator(Covariance())
 
   def test_lmnn(self):
-    check_estimator(LMNN)
+    check_estimator(LMNN())
 
   def test_lfda(self):
-    check_estimator(LFDA)
+    check_estimator(LFDA())
 
   def test_mlkr(self):
-    check_estimator(MLKR)
+    check_estimator(MLKR())
 
   def test_nca(self):
-    check_estimator(NCA)
+    check_estimator(NCA())
 
   def test_lsml(self):
-    check_estimator(LSML_Supervised)
+    check_estimator(LSML_Supervised())
 
   def test_itml(self):
-    check_estimator(ITML_Supervised)
+    check_estimator(ITML_Supervised())
 
   def test_mmc(self):
-    check_estimator(MMC_Supervised)
+    check_estimator(MMC_Supervised())
 
   def test_sdml(self):
-    check_estimator(Stable_SDML_Supervised)
+    check_estimator(Stable_SDML_Supervised())
 
   def test_rca(self):
-    check_estimator(Stable_RCA_Supervised)
+    check_estimator(Stable_RCA_Supervised())
 
   def test_scml(self):
-    check_estimator(SCML_Supervised)
+    check_estimator(SCML_Supervised())
 
 
 RNG = check_random_state(0)
@@ -121,7 +118,8 @@ def test_array_like_inputs(estimator, build_dataset, with_preprocessor):
 
   # we subsample the data for the test to be more efficient
   input_data, _, labels, _ = train_test_split(input_data, labels,
-                                              train_size=20)
+                                              train_size=40,
+                                              random_state=42)
   X = X[:10]
 
   estimator = clone(estimator)
@@ -160,7 +158,7 @@ def test_various_scoring_on_tuples_learners(estimator, build_dataset,
                                             with_preprocessor):
  """Tests that scikit-learn's scoring returns something finite,
  for other scoring than default scoring. (List of scikit-learn's scores can be
-  found in sklearn.metrics.scorer). For each type of output (predict,
+  found in sklearn.metrics._scorer). For each type of output (predict,
  predict_proba, decision_function), we test a bunch of scores.
  We only test on pairs learners because quadruplets don't have a y argument.
  """

test/test_triplets_classifiers.py (+1, -1)

@@ -3,7 +3,7 @@
 from sklearn.model_selection import train_test_split
 
 from test.test_utils import triplets_learners, ids_triplets_learners
-from sklearn.utils.testing import set_random_state
+from metric_learn.sklearn_shims import set_random_state
 from sklearn import clone
 import numpy as np
 

test/test_utils.py (+1, -1)

@@ -5,7 +5,7 @@
 from numpy.testing import assert_array_equal, assert_equal
 from sklearn.model_selection import train_test_split
 from sklearn.utils import check_random_state, shuffle
-from sklearn.utils.testing import set_random_state
+from metric_learn.sklearn_shims import set_random_state
 from sklearn.base import clone
 from metric_learn._util import (check_input, make_context, preprocess_tuples,
                                 make_name, preprocess_points,
