Skip to content

Commit 505e456

Browse files
committed
nngraph: fix radius cKDTree (PR #21)
1 parent 695272b commit 505e456

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

pygsp/graphs/nngraphs/nngraph.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,16 @@ def _radius_sp_kdtree(features, radius, metric, order):
107107
def _radius_sp_ckdtree(features, radius, metric, order):
108108
p = order if metric == 'minkowski' else _metrics['scipy-ckdtree'][metric]
109109
n_vertices, _ = features.shape
110-
kdt = spatial.cKDTree(features)
111-
nn = kdt.query_ball_point(features, r=radius, p=p, n_jobs=-1)
112-
D = []
113-
NN = []
114-
for k in range(n_vertices):
115-
x = np.tile(features[k, :], (len(nn[k]), 1))
116-
d = np.linalg.norm(x - features[nn[k], :],
117-
ord=_metrics['scipy-ckdtree'][metric],
118-
axis=1)
119-
nidx = d.argsort()
120-
NN.append(np.take(nn[k], nidx))
121-
D.append(np.sort(d))
122-
return NN, D
110+
tree = spatial.cKDTree(features)
111+
D, NN = tree.query(features, k=n_vertices, distance_upper_bound=radius,
112+
p=p, n_jobs=-1)
113+
distances = []
114+
neighbors = []
115+
for d, n in zip(D, NN):
116+
mask = (d != np.inf)
117+
distances.append(d[mask])
118+
neighbors.append(n[mask])
119+
return neighbors, distances
123120

124121

125122
def _knn_sp_pdist(features, num_neighbors, metric, order):

0 commit comments

Comments
 (0)