Skip to content

Commit 720646e

Browse files
committed
simplify test_nngraph (PR #21)
1 parent 9b5d8c0 commit 720646e

File tree

1 file changed

+21
-36
lines changed

1 file changed

+21
-36
lines changed

pygsp/tests/test_graphs.py

+21-36
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ def test_set_coordinates(self):
348348
self.assertRaises(ValueError, G.set_coordinates, 'invalid')
349349

350350
def test_nngraph(self, n_vertices=30):
351+
"""Test all the combinations of metric, kind, backend."""
351352
features = np.random.RandomState(42).normal(size=(n_vertices, 3))
352353
metrics = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
353354
backends = ['scipy-kdtree', 'scipy-ckdtree', 'scipy-pdist', 'nmslib',
@@ -356,46 +357,30 @@ def test_nngraph(self, n_vertices=30):
356357

357358
for backend in backends:
358359
for metric in metrics:
359-
if ((backend == 'flann' and metric == 'max_dist') or
360-
(backend == 'nmslib' and metric == 'minkowski')):
361-
self.assertRaises(ValueError, graphs.NNGraph, features,
362-
kind='knn', backend=backend,
363-
metric=metric)
364-
self.assertRaises(ValueError, graphs.NNGraph, features,
365-
kind='radius', backend=backend,
366-
metric=metric)
367-
else:
368-
if backend == 'nmslib':
369-
self.assertRaises(ValueError, graphs.NNGraph, features,
370-
kind='radius', backend=backend,
371-
metric=metric, order=order)
360+
for kind in ['knn', 'radius']:
361+
params = dict(features=features, metric=metric,
362+
order=order, kind=kind, backend=backend)
363+
# Unsupported combinations.
364+
if backend == 'flann' and metric == 'max_dist':
365+
self.assertRaises(ValueError, graphs.NNGraph, **params)
366+
elif backend == 'nmslib' and metric == 'minkowski':
367+
self.assertRaises(ValueError, graphs.NNGraph, **params)
368+
elif backend == 'nmslib' and kind == 'radius':
369+
self.assertRaises(ValueError, graphs.NNGraph, **params)
372370
else:
373-
graphs.NNGraph(features, kind='radius',
374-
backend=backend,
375-
metric=metric, order=order)
376-
graphs.NNGraph(features, kind='knn',
377-
backend=backend,
378-
metric=metric, order=order)
379-
graphs.NNGraph(features, kind='knn',
380-
backend=backend,
381-
metric=metric, order=order,
382-
center=False)
383-
graphs.NNGraph(features, kind='knn',
384-
backend=backend,
385-
metric=metric, order=order,
386-
rescale=False)
387-
graphs.NNGraph(features, kind='knn',
388-
backend=backend,
389-
metric=metric, order=order,
390-
rescale=False, center=False)
371+
graphs.NNGraph(**params, center=False)
372+
graphs.NNGraph(**params, rescale=False)
373+
graphs.NNGraph(**params, center=False, rescale=False)
374+
375+
# Invalid parameters.
376+
self.assertRaises(ValueError, graphs.NNGraph, features,
377+
metric='invalid')
391378
self.assertRaises(ValueError, graphs.NNGraph, features,
392-
kind='invalid', backend=backend,
393-
metric=metric)
379+
kind='invalid')
394380
self.assertRaises(ValueError, graphs.NNGraph, features,
395-
kind='knn', backend='invalid',
396-
metric=metric)
381+
backend='invalid')
397382
self.assertRaises(ValueError, graphs.NNGraph, features,
398-
kind='knn', k=n_vertices+1)
383+
kind='knn', k=n_vertices+1)
399384

400385
def test_nngraph_consistency(self):
401386
features = np.arange(90).reshape(30, 3)

0 commit comments

Comments
 (0)