Skip to content

Commit 1309e92

Browse files
committed
implement and use scipy-ckdtree as default (faster than kdtree)
1 parent 4a4d597 commit 1309e92

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

pygsp/graphs/nngraphs/nngraph.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
'manhattan': 1,
1919
'max_dist': np.inf
2020
},
21+
'scipy-ckdtree': {
22+
'euclidean': 2,
23+
'manhattan': 1,
24+
'max_dist': np.inf
25+
},
2126
'scipy-pdist' : {
2227
'euclidean': 'euclidean',
2328
'manhattan': 'cityblock',
@@ -44,6 +49,13 @@ def _knn_sp_kdtree(X, num_neighbors, dist_type, order=0):
4449
p=_dist_translation['scipy-kdtree'][dist_type])
4550
return NN, D
4651

52+
def _knn_sp_ckdtree(X, num_neighbors, dist_type, order=0):
    """Query the k nearest neighbors of every point with scipy's cKDTree.

    Parameters
    ----------
    X : array_like
        Points used both to build the tree and as the query set.
    num_neighbors : int
        Number of neighbors wanted per point; the tree is queried for
        ``num_neighbors + 1`` entries.
    dist_type : str
        Key into ``_dist_translation['scipy-ckdtree']`` selecting the
        Minkowski p-norm handed to the query.
    order : float, optional
        Not read by this function; kept for signature parity with the
        other backends.

    Returns
    -------
    NN, D
        Neighbor indices and the matching distances, in that order
        (the reverse of what ``cKDTree.query`` returns).
    """
    tree = sps.cKDTree(X)
    # One more than requested: presumably because every query point is
    # itself stored in the tree and comes back as its own closest match.
    how_many = num_neighbors + 1
    minkowski_p = _dist_translation['scipy-ckdtree'][dist_type]
    distances, indices = tree.query(X, k=how_many, p=minkowski_p)
    return indices, distances
57+
58+
4759
def _knn_flann(X, num_neighbors, dist_type, order):
4860
# the combination FLANN + max_dist produces incorrect results
4961
# do not allow it
@@ -66,9 +78,27 @@ def _knn_flann(X, num_neighbors, dist_type, order):
6678
def _radius_sp_kdtree(X, epsilon, dist_type, order=0):
    """Radius search around every point using scipy's KDTree.

    Parameters
    ----------
    X : array_like
        Points used both to build the tree and as the query set.
    epsilon : float
        Passed as ``distance_upper_bound`` to the tree query.
    dist_type : str
        Key into ``_dist_translation['scipy-kdtree']`` selecting the
        Minkowski p-norm handed to the query.
    order : float, optional
        Not read by this function; kept for signature parity with the
        other backends.

    Returns
    -------
    NN, D
        Neighbor indices and the matching distances, in that order
        (the reverse of what ``KDTree.query`` returns).
    """
    tree = sps.KDTree(X)
    minkowski_p = _dist_translation['scipy-kdtree'][dist_type]
    # NOTE(review): ``k=None`` ("all neighbors within the bound") is
    # deprecated in recent SciPy releases in favour of
    # ``query_ball_point`` -- confirm against the supported SciPy version.
    distances, indices = tree.query(
        X, k=None, distance_upper_bound=epsilon, p=minkowski_p)
    return indices, distances
7183

84+
def _radius_sp_ckdtree(X, epsilon, dist_type, order=0):
    """Radius search around every point using scipy's cKDTree.

    ``cKDTree.query_ball_point`` only returns neighbor indices, so the
    distances are recomputed here and each neighbor list is ordered by
    increasing distance.

    Parameters
    ----------
    X : array_like, two-dimensional
        Points, one per row; used both to build the tree and as the
        query set.
    epsilon : float
        Search radius passed as ``r`` to ``query_ball_point``.
    dist_type : str
        Key into ``_dist_translation['scipy-ckdtree']`` selecting the
        Minkowski p-norm, used for both the tree query and the distance
        recomputation.
    order : float, optional
        Not read by this function; kept for signature parity with the
        other backends.

    Returns
    -------
    NN : list of ndarray
        For each row of ``X``, the indices of the rows within
        ``epsilon``, sorted by distance.
    D : list of ndarray
        The distances matching ``NN`` entry by entry.
    """
    X = np.asarray(X)
    # Loop-invariant: look the norm order up once instead of per point.
    p = _dist_translation['scipy-ckdtree'][dist_type]
    kdt = sps.cKDTree(X)
    neighbors = kdt.query_ball_point(X, r=epsilon, p=p)

    NN = []
    D = []
    for k, idx in enumerate(neighbors):
        # Broadcasting subtracts the single row X[k] from every neighbor
        # row, so no tiled copy (np.matlib.repmat) is needed and the
        # function no longer depends on numpy.matlib being imported.
        d = np.linalg.norm(X[k] - X[idx], ord=p, axis=1)
        rank = d.argsort()
        NN.append(np.take(idx, rank))
        # Reuse the permutation rather than sorting the distances again.
        D.append(d[rank])
    return NN, D
100+
101+
72102
def _knn_sp_pdist(X, num_neighbors, dist_type, order):
73103
pd = sps.distance.squareform(
74104
sps.distance.pdist(X,
@@ -142,7 +172,8 @@ class NNGraph(Graph):
142172
is 'knn').
143173
backend : {'scipy-kdtree', 'scipy-ckdtree', 'scipy-pdist', 'flann'}
144174
Type of the backend for graph construction.
145-
- 'scipy-kdtree'(default) will use scipy.spatial.KDTree
175+
- 'scipy-kdtree' will use scipy.spatial.KDTree
176+
- 'scipy-ckdtree'(default) will use scipy.spatial.cKDTree
146177
- 'scipy-pdist' will use scipy.spatial.distance.pdist (slowest but exact)
147178
- 'flann' use Fast Library for Approximate Nearest Neighbors (FLANN)
148179
center : bool, optional
@@ -183,7 +214,7 @@ class NNGraph(Graph):
183214
184215
"""
185216

186-
def __init__(self, Xin, NNtype='knn', backend='scipy-kdtree', center=True,
217+
def __init__(self, Xin, NNtype='knn', backend='scipy-ckdtree', center=True,
187218
rescale=True, k=10, sigma=0.1, epsilon=0.01, gtype=None,
188219
plotting={}, symmetrize_type='average', dist_type='euclidean',
189220
order=0, **kwargs):
@@ -197,15 +228,18 @@ def __init__(self, Xin, NNtype='knn', backend='scipy-kdtree', center=True,
197228
self.sigma = sigma
198229
self.epsilon = epsilon
199230
_dist_translation['scipy-kdtree']['minkowski'] = order
231+
_dist_translation['scipy-ckdtree']['minkowski'] = order
200232

201233
self._nn_functions = {
202234
'knn': {
203235
'scipy-kdtree': _knn_sp_kdtree,
236+
'scipy-ckdtree': _knn_sp_ckdtree,
204237
'scipy-pdist': _knn_sp_pdist,
205238
'flann': _knn_flann
206239
},
207240
'radius': {
208241
'scipy-kdtree': _radius_sp_kdtree,
242+
'scipy-ckdtree': _radius_sp_ckdtree,
209243
'scipy-pdist': _radius_sp_pdist,
210244
'flann': _radius_flann
211245
},

pygsp/tests/test_graphs.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_set_coordinates(self):
183183
def test_nngraph(self):
184184
Xin = np.arange(90).reshape(30, 3)
185185
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
186-
backends = ['scipy-kdtree', 'scipy-pdist', 'flann']
186+
backends = ['scipy-kdtree', 'scipy-ckdtree', 'scipy-pdist', 'flann']
187187
order=3 # for minkowski, FLANN only accepts integer orders
188188

189189
for cur_backend in backends:
@@ -194,9 +194,10 @@ def test_nngraph(self):
194194
NNtype='knn', backend=cur_backend,
195195
dist_type=dist_type)
196196
else:
197-
graphs.NNGraph(Xin, NNtype='radius',
198-
backend=cur_backend,
199-
dist_type=dist_type, order=order)
197+
if cur_backend != 'flann': #pyflann fails on radius query
198+
graphs.NNGraph(Xin, NNtype='radius',
199+
backend=cur_backend,
200+
dist_type=dist_type, order=order)
200201
graphs.NNGraph(Xin, NNtype='knn',
201202
backend=cur_backend,
202203
dist_type=dist_type, order=order)
@@ -208,9 +209,9 @@ def test_nngraph(self):
208209
dist_type=dist_type)
209210

210211
def test_nngraph_consistency(self):
211-
Xin = np.random.uniform(-5, 5, (60, 3))
212+
Xin = np.arange(90).reshape(30, 3)
212213
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
213-
backends = ['scipy-kdtree', 'flann']
214+
backends = ['scipy-kdtree', 'scipy-ckdtree', 'flann']
214215
num_neighbors=4
215216
epsilon=0.1
216217

0 commit comments

Comments
 (0)