@@ -348,6 +348,7 @@ def test_set_coordinates(self):
348
348
self .assertRaises (ValueError , G .set_coordinates , 'invalid' )
349
349
350
350
def test_nngraph (self , n_vertices = 30 ):
351
+ """Test all the combinations of metric, kind, backend."""
351
352
features = np .random .RandomState (42 ).normal (size = (n_vertices , 3 ))
352
353
metrics = ['euclidean' , 'manhattan' , 'max_dist' , 'minkowski' ]
353
354
backends = ['scipy-kdtree' , 'scipy-ckdtree' , 'scipy-pdist' , 'nmslib' ,
@@ -356,46 +357,30 @@ def test_nngraph(self, n_vertices=30):
356
357
357
358
for backend in backends :
358
359
for metric in metrics :
359
- if ((backend == 'flann' and metric == 'max_dist' ) or
360
- (backend == 'nmslib' and metric == 'minkowski' )):
361
- self .assertRaises (ValueError , graphs .NNGraph , features ,
362
- kind = 'knn' , backend = backend ,
363
- metric = metric )
364
- self .assertRaises (ValueError , graphs .NNGraph , features ,
365
- kind = 'radius' , backend = backend ,
366
- metric = metric )
367
- else :
368
- if backend == 'nmslib' :
369
- self .assertRaises (ValueError , graphs .NNGraph , features ,
370
- kind = 'radius' , backend = backend ,
371
- metric = metric , order = order )
360
+ for kind in ['knn' , 'radius' ]:
361
+ params = dict (features = features , metric = metric ,
362
+ order = order , kind = kind , backend = backend )
363
+ # Unsupported combinations.
364
+ if backend == 'flann' and metric == 'max_dist' :
365
+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
366
+ elif backend == 'nmslib' and metric == 'minkowski' :
367
+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
368
+ elif backend == 'nmslib' and kind == 'radius' :
369
+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
372
370
else :
373
- graphs .NNGraph (features , kind = 'radius' ,
374
- backend = backend ,
375
- metric = metric , order = order )
376
- graphs .NNGraph (features , kind = 'knn' ,
377
- backend = backend ,
378
- metric = metric , order = order )
379
- graphs .NNGraph (features , kind = 'knn' ,
380
- backend = backend ,
381
- metric = metric , order = order ,
382
- center = False )
383
- graphs .NNGraph (features , kind = 'knn' ,
384
- backend = backend ,
385
- metric = metric , order = order ,
386
- rescale = False )
387
- graphs .NNGraph (features , kind = 'knn' ,
388
- backend = backend ,
389
- metric = metric , order = order ,
390
- rescale = False , center = False )
371
+ graphs .NNGraph (** params , center = False )
372
+ graphs .NNGraph (** params , rescale = False )
373
+ graphs .NNGraph (** params , center = False , rescale = False )
374
+
375
+ # Invalid parameters.
376
+ self .assertRaises (ValueError , graphs .NNGraph , features ,
377
+ metric = 'invalid' )
391
378
self .assertRaises (ValueError , graphs .NNGraph , features ,
392
- kind = 'invalid' , backend = backend ,
393
- metric = metric )
379
+ kind = 'invalid' )
394
380
self .assertRaises (ValueError , graphs .NNGraph , features ,
395
- kind = 'knn' , backend = 'invalid' ,
396
- metric = metric )
381
+ backend = 'invalid' )
397
382
self .assertRaises (ValueError , graphs .NNGraph , features ,
398
- kind = 'knn' , k = n_vertices + 1 )
383
+ kind = 'knn' , k = n_vertices + 1 )
399
384
400
385
def test_nngraph_consistency (self ):
401
386
features = np .arange (90 ).reshape (30 , 3 )
0 commit comments