Skip to content

Commit 25ec6d2

Browse files
committed
implement radius nn graph with flann
1 parent 6f473fa commit 25ec6d2

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

pygsp/graphs/nngraphs/nngraph.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import traceback
44

55
import numpy as np
6-
from scipy import sparse, spatial
6+
from scipy import sparse
7+
import scipy.spatial as sps
78

89
from pygsp import utils
910
from pygsp.graphs import Graph # prevent circular import in Python < 3.5
@@ -38,17 +39,17 @@ def _import_pfl():
3839

3940

4041
def _knn_sp_kdtree(X, num_neighbors, dist_type, order=0):
41-
kdt = spatial.KDTree(X)
42+
kdt = sps.KDTree(X)
4243
D, NN = kdt.query(X, k=(num_neighbors + 1),
4344
p=_dist_translation['scipy-kdtree'][dist_type])
4445
return NN, D
4546

4647
def _knn_flann(X, num_neighbors, dist_type, order):
47-
pfl = _import_pfl()
4848
# the combination FLANN + max_dist produces incorrect results
4949
# do not allow it
5050
if dist_type == 'max_dist':
5151
raise ValueError('FLANN and max_dist is not supported')
52+
pfl = _import_pfl()
5253
pfl.set_distance_type(dist_type, order=order)
5354
flann = pfl.FLANN()
5455

@@ -60,26 +61,26 @@ def _knn_flann(X, num_neighbors, dist_type, order):
6061
return NN, D
6162

6263
def _radius_sp_kdtree(X, epsilon, dist_type, order=0):
63-
kdt = spatial.KDTree(X)
64+
kdt = sps.KDTree(X)
6465
D, NN = kdt.query(X, k=None, distance_upper_bound=epsilon,
6566
p=_dist_translation['scipy-kdtree'][dist_type])
6667
return NN, D
6768

6869
def _knn_sp_pdist(X, num_neighbors, dist_type, order):
69-
pd = spatial.distance.squareform(
70-
spatial.distance.pdist(X,
71-
_dist_translation['scipy-pdist'][dist_type],
72-
p=order))
70+
pd = sps.distance.squareform(
71+
sps.distance.pdist(X,
72+
metric=_dist_translation['scipy-pdist'][dist_type],
73+
p=order))
7374
pds = np.sort(pd)[:, 0:num_neighbors+1]
7475
pdi = pd.argsort()[:, 0:num_neighbors+1]
7576
return pdi, pds
7677

7778
def _radius_sp_pdist(X, epsilon, dist_type, order):
7879
N, dim = np.shape(X)
79-
pd = spatial.distance.squareform(
80-
spatial.distance.pdist(X,
81-
_dist_translation['scipy-pdist'][dist_type],
82-
p=order))
80+
pd = sps.distance.squareform(
81+
sps.distance.pdist(X,
82+
metric=_dist_translation['scipy-pdist'][dist_type],
83+
p=order))
8384
pdf = pd < epsilon
8485
D = []
8586
NN = []
@@ -92,16 +93,25 @@ def _radius_sp_pdist(X, epsilon, dist_type, order):
9293
return NN, D
9394

9495
def _radius_flann(X, epsilon, dist_type, order=0):
95-
pfl = _import_pfl()
96+
N, dim = np.shape(X)
9697
# the combination FLANN + max_dist produces incorrect results
9798
# do not allow it
9899
if dist_type == 'max_dist':
99100
raise ValueError('FLANN and max_dist is not supported')
101+
102+
pfl = _import_pfl()
100103
pfl.set_distance_type(dist_type, order=order)
101104
flann = pfl.FLANN()
102105
flann.build_index(X)
103106

107+
D = []
108+
NN = []
109+
for k in range(N):
110+
nn, d = flann.nn_radius(X[k, :], epsilon)
111+
D.append(d)
112+
NN.append(nn)
104113
flann.delete_index()
114+
return NN, D
105115

106116
class NNGraph(Graph):
107117
r"""Nearest-neighbor graph from given point cloud.

pygsp/tests/test_graphs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,9 @@ def test_nngraph(self):
192192
NNtype='knn', backend=cur_backend,
193193
dist_type=dist_type)
194194
else:
195-
if cur_backend != 'flann':
196-
graphs.NNGraph(Xin, NNtype='radius',
197-
backend=cur_backend,
198-
dist_type=dist_type, order=order)
195+
graphs.NNGraph(Xin, NNtype='radius',
196+
backend=cur_backend,
197+
dist_type=dist_type, order=order)
199198
graphs.NNGraph(Xin, NNtype='knn',
200199
backend=cur_backend,
201200
dist_type=dist_type, order=order)

0 commit comments

Comments
 (0)