|
9 | 9 | from sklearn.exceptions import NotFittedError |
10 | 10 | import numpy as np |
11 | 11 | import matplotlib.pyplot as plt |
| 12 | +import scikitplot.plotters as skplt |
12 | 13 |
|
13 | 14 |
|
14 | 15 | def convert_labels_into_string(y_true): |
@@ -201,6 +202,9 @@ def test_ax(self): |
201 | 202 | out_ax = clf.plot_confusion_matrix(self.X, self.y, ax=ax) |
202 | 203 | assert ax is out_ax |
203 | 204 |
|
| 205 | + def test_array_like(self): |
| 206 | + ax = skplt.plot_confusion_matrix([0, 1], [1, 0]) |
| 207 | + |
204 | 208 |
|
205 | 209 | class TestPlotROCCurve(unittest.TestCase): |
206 | 210 | def setUp(self): |
@@ -272,6 +276,9 @@ def test_invalid_curve_arg(self): |
272 | 276 | self.assertRaises(ValueError, clf.plot_roc_curve, self.X, self.y, |
273 | 277 | curves='zzz') |
274 | 278 |
|
| 279 | + def test_array_like(self): |
| 280 | + ax = skplt.plot_roc_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]]) |
| 281 | + |
275 | 282 |
|
276 | 283 | class TestPlotKSStatistic(unittest.TestCase): |
277 | 284 | def setUp(self): |
@@ -333,6 +340,9 @@ def test_ax(self): |
333 | 340 | out_ax = clf.plot_ks_statistic(self.X, self.y, ax=ax) |
334 | 341 | assert ax is out_ax |
335 | 342 |
|
| 343 | + def test_array_like(self): |
| 344 | + ax = skplt.plot_ks_statistic([0, 1], [[0.8, 0.2], [0.2, 0.8]]) |
| 345 | + |
336 | 346 |
|
337 | 347 | class TestPlotPrecisionRecall(unittest.TestCase): |
338 | 348 | def setUp(self): |
@@ -403,6 +413,9 @@ def test_invalid_curve_arg(self): |
403 | 413 | self.assertRaises(ValueError, clf.plot_precision_recall_curve, self.X, self.y, |
404 | 414 | curves='zzz') |
405 | 415 |
|
| 416 | + def test_array_like(self): |
| 417 | + ax = skplt.plot_precision_recall_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]]) |
| 418 | + |
406 | 419 |
|
407 | 420 | class TestFeatureImportances(unittest.TestCase): |
408 | 421 | def setUp(self): |
|
0 commit comments