Skip to content

Commit 7102bc5

Browse files
Resolve some SCML issues
The sklearn KMeans class now warns if you don't provide a value for the n_init parameter. I'm setting it to the original default, but we may want to consider setting it to 'auto' in the future.
1 parent 8d5059c commit 7102bc5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

metric_learn/scml.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def _initialize_basis_supervised(self, X, y):
558558
case one is selected.
559559
"""
560560

561-
if self.basis == 'lda':
561+
if isinstance(self.basis, str) and self.basis == 'lda':
562562
basis, n_basis = self._generate_bases_LDA(X, y)
563563
else:
564564
basis, n_basis = None, None
@@ -606,8 +606,8 @@ def _generate_bases_LDA(self, X, y):
606606
"should be smaller than %d" %
607607
(n_basis, X.shape[0]*2*num_eig))
608608

609-
kmeans = KMeans(n_clusters=n_clusters, random_state=self.random_state,
610-
algorithm='elkan').fit(X)
609+
kmeans = KMeans(n_clusters=n_clusters, n_init=10,
610+
random_state=self.random_state, algorithm='elkan').fit(X)
611611
cX = kmeans.cluster_centers_
612612

613613
n_scales = 2

0 commit comments

Comments
 (0)