Skip to content

Commit 6f473fa

Browse files
committed
implement nn graph using pdist using radius
1 parent 62fc0ce commit 6f473fa

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

pygsp/graphs/nngraphs/nngraph.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,43 @@ def _radius_sp_kdtree(X, epsilon, dist_type, order=0):
6565
p=_dist_translation['scipy-kdtree'][dist_type])
6666
return NN, D
6767

68-
def _knn_sp_pdist(X, num_neighbors, dist_type, _order):
68+
def _knn_sp_pdist(X, num_neighbors, dist_type, order):
6969
pd = spatial.distance.squareform(
7070
spatial.distance.pdist(X,
7171
_dist_translation['scipy-pdist'][dist_type],
72-
p=_order))
72+
p=order))
7373
pds = np.sort(pd)[:, 0:num_neighbors+1]
7474
pdi = pd.argsort()[:, 0:num_neighbors+1]
7575
return pdi, pds
7676

77-
def _radius_sp_pdist(_X, _epsilon, _dist_type, order=0):
78-
raise NotImplementedError()
77+
def _radius_sp_pdist(X, epsilon, dist_type, order):
78+
N, dim = np.shape(X)
79+
pd = spatial.distance.squareform(
80+
spatial.distance.pdist(X,
81+
_dist_translation['scipy-pdist'][dist_type],
82+
p=order))
83+
pdf = pd < epsilon
84+
D = []
85+
NN = []
86+
for k in range(N):
87+
v = pd[k, pdf[k, :]]
88+
# use the same conventions as in scipy.distance.kdtree
89+
NN.append(v.argsort())
90+
D.append(np.sort(v))
91+
92+
return NN, D
7993

80-
def _radius_flann(_X, _epsilon, _dist_type, order=0):
81-
raise NotImplementedError()
94+
def _radius_flann(X, epsilon, dist_type, order=0):
95+
pfl = _import_pfl()
96+
# the combination FLANN + max_dist produces incorrect results
97+
# do not allow it
98+
if dist_type == 'max_dist':
99+
raise ValueError('FLANN and max_dist is not supported')
100+
pfl.set_distance_type(dist_type, order=order)
101+
flann = pfl.FLANN()
102+
flann.build_index(X)
103+
104+
flann.delete_index()
82105

83106
class NNGraph(Graph):
84107
r"""Nearest-neighbor graph from given point cloud.

pygsp/tests/test_graphs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_nngraph(self):
192192
NNtype='knn', backend=cur_backend,
193193
dist_type=dist_type)
194194
else:
195-
if cur_backend == 'scipy-kdtree':
195+
if cur_backend != 'flann':
196196
graphs.NNGraph(Xin, NNtype='radius',
197197
backend=cur_backend,
198198
dist_type=dist_type, order=order)

0 commit comments

Comments
 (0)