Skip to content

Commit 96b628e

Browse files
committed
flann returns the squared distance when called with 'euclidean' distance -> fix
fix radius nn graph for spdist
1 parent 25ec6d2 commit 96b628e

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

pygsp/graphs/nngraphs/nngraph.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _knn_flann(X, num_neighbors, dist_type, order):
4949
# do not allow it
5050
if dist_type == 'max_dist':
5151
raise ValueError('FLANN and max_dist is not supported')
52+
5253
pfl = _import_pfl()
5354
pfl.set_distance_type(dist_type, order=order)
5455
flann = pfl.FLANN()
@@ -58,6 +59,8 @@ def _knn_flann(X, num_neighbors, dist_type, order):
5859
# seems to work best).
5960
NN, D = flann.nn(X, X, num_neighbors=(num_neighbors + 1),
6061
algorithm='kdtree')
62+
if dist_type == 'euclidean': # flann returns squared distances
63+
return NN, np.sqrt(D)
6164
return NN, D
6265

6366
def _radius_sp_kdtree(X, epsilon, dist_type, order=0):
@@ -86,8 +89,9 @@ def _radius_sp_pdist(X, epsilon, dist_type, order):
8689
NN = []
8790
for k in range(N):
8891
v = pd[k, pdf[k, :]]
92+
d = pd[k, :].argsort()
8993
# use the same conventions as in scipy.distance.kdtree
90-
NN.append(v.argsort())
94+
NN.append(d[0:len(v)])
9195
D.append(np.sort(v))
9296

9397
return NN, D
@@ -98,21 +102,32 @@ def _radius_flann(X, epsilon, dist_type, order=0):
98102
# do not allow it
99103
if dist_type == 'max_dist':
100104
raise ValueError('FLANN and max_dist is not supported')
101-
102105
pfl = _import_pfl()
106+
103107
pfl.set_distance_type(dist_type, order=order)
104108
flann = pfl.FLANN()
105109
flann.build_index(X)
106110

107111
D = []
108112
NN = []
109113
for k in range(N):
110-
nn, d = flann.nn_radius(X[k, :], epsilon)
114+
nn, d = flann.nn_radius(X[k, :], epsilon*epsilon)
111115
D.append(d)
112116
NN.append(nn)
113117
flann.delete_index()
118+
if dist_type == 'euclidean': # flann returns squared distances
119+
return NN, np.sqrt(D)
114120
return NN, D
115121

122+
def center_input(X, N):
123+
return X - np.kron(np.ones((N, 1)), np.mean(X, axis=0))
124+
125+
def rescale_input(X, N, d):
126+
bounding_radius = 0.5 * np.linalg.norm(np.amax(X, axis=0) -
127+
np.amin(X, axis=0), 2)
128+
scale = np.power(N, 1. / float(min(d, 3))) / 10.
129+
return X * scale / bounding_radius
130+
116131
class NNGraph(Graph):
117132
r"""Nearest-neighbor graph from given point cloud.
118133
@@ -207,14 +222,10 @@ def __init__(self, Xin, NNtype='knn', backend='scipy-kdtree', center=True,
207222
Xout = self.Xin
208223

209224
if self.center:
210-
Xout = self.Xin - np.kron(np.ones((N, 1)),
211-
np.mean(self.Xin, axis=0))
225+
Xout = center_input(Xout, N)
212226

213227
if self.rescale:
214-
bounding_radius = 0.5 * np.linalg.norm(np.amax(Xout, axis=0) -
215-
np.amin(Xout, axis=0), 2)
216-
scale = np.power(N, 1. / float(min(d, 3))) / 10.
217-
Xout *= scale / bounding_radius
228+
Xout = rescale_input(Xout, N, d)
218229

219230

220231

pygsp/tests/test_graphs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def test_nngraph(self):
187187

188188
for cur_backend in backends:
189189
for dist_type in dist_types:
190+
#print("backend={} dist={}".format(cur_backend, dist_type))
190191
if cur_backend == 'flann' and dist_type == 'max_dist':
191192
self.assertRaises(ValueError, graphs.NNGraph, Xin,
192193
NNtype='knn', backend=cur_backend,
@@ -199,6 +200,20 @@ def test_nngraph(self):
199200
backend=cur_backend,
200201
dist_type=dist_type, order=order)
201202

203+
def test_nngraph_consistency(self):
204+
#Xin = np.arange(180).reshape(60, 3)
205+
Xin = np.random.uniform(-5, 5, (60, 3))
206+
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
207+
backends = ['scipy-kdtree', 'flann']
208+
num_neighbors=5
209+
210+
G = graphs.NNGraph(Xin, NNtype='knn',
211+
backend='scipy-pdist', k=num_neighbors)
212+
for cur_backend in backends:
213+
for dist_type in dist_types:
214+
Gt = graphs.NNGraph(Xin, NNtype='knn',
215+
backend=cur_backend, k=num_neighbors)
216+
202217
def test_bunny(self):
203218
graphs.Bunny()
204219

0 commit comments

Comments
 (0)