|
4 | 4 | import metric_learn
|
5 | 5 | import numpy as np
|
6 | 6 | from sklearn import clone
|
7 |
| -from sklearn.utils.testing import set_random_state |
8 | 7 | from test.test_utils import ids_metric_learners, metric_learners, remove_y
|
| 8 | +from metric_learn.sklearn_shims import set_random_state, SKLEARN_AT_LEAST_0_22 |
9 | 9 |
|
10 | 10 |
|
11 | 11 | def remove_spaces(s):
|
12 | 12 | return re.sub(r'\s+', '', s)
|
13 | 13 |
|
14 | 14 |
|
| 15 | +def sk_repr_kwargs(def_kwargs, nndef_kwargs): |
| 16 | + """Given the non-default arguments, and the default |
| 17 | + keywords arguments, build the string that will appear |
| 18 | + in the __repr__ of the estimator, depending on the |
| 19 | + version of scikit-learn. |
| 20 | + """ |
| 21 | + if SKLEARN_AT_LEAST_0_22: |
| 22 | + def_kwargs = {} |
| 23 | + def_kwargs.update(nndef_kwargs) |
| 24 | + args_str = ",".join(f"{key}={repr(value)}" |
| 25 | + for key, value in def_kwargs.items()) |
| 26 | + return args_str |
| 27 | + |
| 28 | + |
15 | 29 | class TestStringRepr(unittest.TestCase):
|
16 | 30 |
|
17 | 31 | def test_covariance(self):
|
| 32 | + def_kwargs = {'preprocessor': None} |
| 33 | + nndef_kwargs = {} |
| 34 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
18 | 35 | self.assertEqual(remove_spaces(str(metric_learn.Covariance())),
|
19 |
| - remove_spaces("Covariance()")) |
| 36 | + remove_spaces(f"Covariance({merged_kwargs})")) |
20 | 37 |
|
21 | 38 | def test_lmnn(self):
|
| 39 | + def_kwargs = {'convergence_tol': 0.001, 'init': 'auto', 'k': 3, |
| 40 | + 'learn_rate': 1e-07, 'max_iter': 1000, 'min_iter': 50, |
| 41 | + 'n_components': None, 'preprocessor': None, |
| 42 | + 'random_state': None, 'regularization': 0.5, |
| 43 | + 'verbose': False} |
| 44 | + nndef_kwargs = {'convergence_tol': 0.01, 'k': 6} |
| 45 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
22 | 46 | self.assertEqual(
|
23 | 47 | remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, k=6))),
|
24 |
| - remove_spaces("LMNN(convergence_tol=0.01, k=6)")) |
| 48 | + remove_spaces(f"LMNN({merged_kwargs})")) |
25 | 49 |
|
26 | 50 | def test_nca(self):
|
| 51 | + def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None, |
| 52 | + 'preprocessor': None, 'random_state': None, 'tol': None, |
| 53 | + 'verbose': False} |
| 54 | + nndef_kwargs = {'max_iter': 42} |
| 55 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
27 | 56 | self.assertEqual(remove_spaces(str(metric_learn.NCA(max_iter=42))),
|
28 |
| - remove_spaces("NCA(max_iter=42)")) |
| 57 | + remove_spaces(f"NCA({merged_kwargs})")) |
29 | 58 |
|
30 | 59 | def test_lfda(self):
|
| 60 | + def_kwargs = {'embedding_type': 'weighted', 'k': None, |
| 61 | + 'n_components': None, 'preprocessor': None} |
| 62 | + nndef_kwargs = {'k': 2} |
| 63 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
31 | 64 | self.assertEqual(remove_spaces(str(metric_learn.LFDA(k=2))),
|
32 |
| - remove_spaces("LFDA(k=2)")) |
| 65 | + remove_spaces(f"LFDA({merged_kwargs})")) |
33 | 66 |
|
34 | 67 | def test_itml(self):
|
| 68 | + def_kwargs = {'convergence_threshold': 0.001, 'gamma': 1.0, |
| 69 | + 'max_iter': 1000, 'preprocessor': None, |
| 70 | + 'prior': 'identity', 'random_state': None, 'verbose': False} |
| 71 | + nndef_kwargs = {'gamma': 0.5} |
| 72 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
35 | 73 | self.assertEqual(remove_spaces(str(metric_learn.ITML(gamma=0.5))),
|
36 |
| - remove_spaces("ITML(gamma=0.5)")) |
| 74 | + remove_spaces(f"ITML({merged_kwargs})")) |
| 75 | + def_kwargs = {'convergence_threshold': 0.001, 'gamma': 1.0, |
| 76 | + 'max_iter': 1000, 'num_constraints': None, |
| 77 | + 'preprocessor': None, 'prior': 'identity', |
| 78 | + 'random_state': None, 'verbose': False} |
| 79 | + nndef_kwargs = {'num_constraints': 7} |
| 80 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
37 | 81 | self.assertEqual(
|
38 | 82 | remove_spaces(str(metric_learn.ITML_Supervised(num_constraints=7))),
|
39 |
| - remove_spaces("ITML_Supervised(num_constraints=7)")) |
| 83 | + remove_spaces(f"ITML_Supervised({merged_kwargs})")) |
40 | 84 |
|
41 | 85 | def test_lsml(self):
|
| 86 | + def_kwargs = {'max_iter': 1000, 'preprocessor': None, 'prior': 'identity', |
| 87 | + 'random_state': None, 'tol': 0.001, 'verbose': False} |
| 88 | + nndef_kwargs = {'tol': 0.1} |
| 89 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
42 | 90 | self.assertEqual(remove_spaces(str(metric_learn.LSML(tol=0.1))),
|
43 |
| - remove_spaces("LSML(tol=0.1)")) |
| 91 | + remove_spaces(f"LSML({merged_kwargs})")) |
| 92 | + def_kwargs = {'max_iter': 1000, 'num_constraints': None, |
| 93 | + 'preprocessor': None, 'prior': 'identity', |
| 94 | + 'random_state': None, 'tol': 0.001, 'verbose': False, |
| 95 | + 'weights': None} |
| 96 | + nndef_kwargs = {'verbose': True} |
| 97 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
44 | 98 | self.assertEqual(
|
45 | 99 | remove_spaces(str(metric_learn.LSML_Supervised(verbose=True))),
|
46 |
| - remove_spaces("LSML_Supervised(verbose=True)")) |
| 100 | + remove_spaces(f"LSML_Supervised({merged_kwargs})")) |
47 | 101 |
|
48 | 102 | def test_sdml(self):
|
| 103 | + def_kwargs = {'balance_param': 0.5, 'preprocessor': None, |
| 104 | + 'prior': 'identity', 'random_state': None, |
| 105 | + 'sparsity_param': 0.01, 'verbose': False} |
| 106 | + nndef_kwargs = {'verbose': True} |
| 107 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
49 | 108 | self.assertEqual(remove_spaces(str(metric_learn.SDML(verbose=True))),
|
50 |
| - remove_spaces("SDML(verbose=True)")) |
| 109 | + remove_spaces(f"SDML({merged_kwargs})")) |
| 110 | + def_kwargs = {'balance_param': 0.5, 'num_constraints': None, |
| 111 | + 'preprocessor': None, 'prior': 'identity', |
| 112 | + 'random_state': None, 'sparsity_param': 0.01, |
| 113 | + 'verbose': False} |
| 114 | + nndef_kwargs = {'sparsity_param': 0.5} |
| 115 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
51 | 116 | self.assertEqual(
|
52 | 117 | remove_spaces(str(metric_learn.SDML_Supervised(sparsity_param=0.5))),
|
53 |
| - remove_spaces("SDML_Supervised(sparsity_param=0.5)")) |
| 118 | + remove_spaces(f"SDML_Supervised({merged_kwargs})")) |
54 | 119 |
|
55 | 120 | def test_rca(self):
|
| 121 | + def_kwargs = {'n_components': None, 'preprocessor': None} |
| 122 | + nndef_kwargs = {'n_components': 3} |
| 123 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
56 | 124 | self.assertEqual(remove_spaces(str(metric_learn.RCA(n_components=3))),
|
57 |
| - remove_spaces("RCA(n_components=3)")) |
| 125 | + remove_spaces(f"RCA({merged_kwargs})")) |
| 126 | + def_kwargs = {'chunk_size': 2, 'n_components': None, 'num_chunks': 100, |
| 127 | + 'preprocessor': None, 'random_state': None} |
| 128 | + nndef_kwargs = {'num_chunks': 5} |
| 129 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
58 | 130 | self.assertEqual(
|
59 | 131 | remove_spaces(str(metric_learn.RCA_Supervised(num_chunks=5))),
|
60 |
| - remove_spaces("RCA_Supervised(num_chunks=5)")) |
| 132 | + remove_spaces(f"RCA_Supervised({merged_kwargs})")) |
61 | 133 |
|
62 | 134 | def test_mlkr(self):
|
| 135 | + def_kwargs = {'init': 'auto', 'max_iter': 1000, |
| 136 | + 'n_components': None, 'preprocessor': None, |
| 137 | + 'random_state': None, 'tol': None, 'verbose': False} |
| 138 | + nndef_kwargs = {'max_iter': 777} |
| 139 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
63 | 140 | self.assertEqual(remove_spaces(str(metric_learn.MLKR(max_iter=777))),
|
64 |
| - remove_spaces("MLKR(max_iter=777)")) |
| 141 | + remove_spaces(f"MLKR({merged_kwargs})")) |
65 | 142 |
|
66 | 143 | def test_mmc(self):
|
| 144 | + def_kwargs = {'convergence_threshold': 0.001, 'diagonal': False, |
| 145 | + 'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100, |
| 146 | + 'max_proj': 10000, 'preprocessor': None, |
| 147 | + 'random_state': None, 'verbose': False} |
| 148 | + nndef_kwargs = {'diagonal': True} |
| 149 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
67 | 150 | self.assertEqual(remove_spaces(str(metric_learn.MMC(diagonal=True))),
|
68 |
| - remove_spaces("MMC(diagonal=True)")) |
| 151 | + remove_spaces(f"MMC({merged_kwargs})")) |
| 152 | + def_kwargs = {'convergence_threshold': 1e-06, 'diagonal': False, |
| 153 | + 'diagonal_c': 1.0, 'init': 'identity', 'max_iter': 100, |
| 154 | + 'max_proj': 10000, 'num_constraints': None, |
| 155 | + 'preprocessor': None, 'random_state': None, |
| 156 | + 'verbose': False} |
| 157 | + nndef_kwargs = {'max_iter': 1} |
| 158 | + merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) |
69 | 159 | self.assertEqual(
|
70 | 160 | remove_spaces(str(metric_learn.MMC_Supervised(max_iter=1))),
|
71 |
| - remove_spaces("MMC_Supervised(max_iter=1)")) |
| 161 | + remove_spaces(f"MMC_Supervised({merged_kwargs})")) |
72 | 162 |
|
73 | 163 |
|
74 | 164 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners,
|
|
0 commit comments