Skip to content

Commit 0ef56e9

Browse files
authored
fix plot_roc_curve, plot_ks_statistic, and plot_precision_recall_curve when passed Python lists instead of Numpy arrays (#32)
1 parent f25825c commit 0ef56e9

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

scikitplot/plotters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves', curves=('micro', 'macro
160160
:align: center
161161
:alt: ROC Curves
162162
"""
163+
y_true = np.array(y_true)
164+
y_probas = np.array(y_probas)
163165

164166
if 'micro' not in curves and 'macro' not in curves and 'each_class' not in curves:
165167
raise ValueError('Invalid argument for curves as it only takes "micro", "macro", or "each_class"')
@@ -282,6 +284,9 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot', ax=None, figs
282284
:align: center
283285
:alt: KS Statistic
284286
"""
287+
y_true = np.array(y_true)
288+
y_probas = np.array(y_probas)
289+
285290
classes = np.unique(y_true)
286291
if len(classes) != 2:
287292
raise ValueError('Cannot calculate KS statistic for data with '
@@ -359,6 +364,9 @@ def plot_precision_recall_curve(y_true, y_probas, title='Precision-Recall Curve'
359364
:align: center
360365
:alt: Precision Recall Curve
361366
"""
367+
y_true = np.array(y_true)
368+
y_probas = np.array(y_probas)
369+
362370
classes = np.unique(y_true)
363371
probas = y_probas
364372

scikitplot/tests/test_classifiers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.exceptions import NotFittedError
1010
import numpy as np
1111
import matplotlib.pyplot as plt
12+
import scikitplot.plotters as skplt
1213

1314

1415
def convert_labels_into_string(y_true):
@@ -201,6 +202,9 @@ def test_ax(self):
201202
out_ax = clf.plot_confusion_matrix(self.X, self.y, ax=ax)
202203
assert ax is out_ax
203204

205+
def test_array_like(self):
206+
ax = skplt.plot_confusion_matrix([0, 1], [1, 0])
207+
204208

205209
class TestPlotROCCurve(unittest.TestCase):
206210
def setUp(self):
@@ -272,6 +276,9 @@ def test_invalid_curve_arg(self):
272276
self.assertRaises(ValueError, clf.plot_roc_curve, self.X, self.y,
273277
curves='zzz')
274278

279+
def test_array_like(self):
280+
ax = skplt.plot_roc_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]])
281+
275282

276283
class TestPlotKSStatistic(unittest.TestCase):
277284
def setUp(self):
@@ -333,6 +340,9 @@ def test_ax(self):
333340
out_ax = clf.plot_ks_statistic(self.X, self.y, ax=ax)
334341
assert ax is out_ax
335342

343+
def test_array_like(self):
344+
ax = skplt.plot_ks_statistic([0, 1], [[0.8, 0.2], [0.2, 0.8]])
345+
336346

337347
class TestPlotPrecisionRecall(unittest.TestCase):
338348
def setUp(self):
@@ -403,6 +413,9 @@ def test_invalid_curve_arg(self):
403413
self.assertRaises(ValueError, clf.plot_precision_recall_curve, self.X, self.y,
404414
curves='zzz')
405415

416+
def test_array_like(self):
417+
ax = skplt.plot_precision_recall_curve([0, 1], [[0.8, 0.2], [0.2, 0.8]])
418+
406419

407420
class TestFeatureImportances(unittest.TestCase):
408421
def setUp(self):

0 commit comments

Comments
 (0)