
Commit 90e5e4f

aronhoff authored and ybubnov committed
Added support for sparse labels (#30)
This patch provides support for sparse labels: a metric constructed with sparse=True treats y_true as integer class indices rather than one-hot vectors.
1 parent a16a8c2 commit 90e5e4f

2 files changed: +69, -40 lines

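For context, a minimal sketch of how the new flag is meant to be used (illustrative only, not part of this commit; the model architecture is a placeholder):

import keras
import keras_metrics

# Hypothetical two-class model; layer sizes are placeholders.
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, activation="softmax", input_dim=4))

# With sparse=True the metrics consume integer class indices in y_true,
# matching the sparse_categorical_crossentropy loss.
model.compile(optimizer="sgd",
              loss="sparse_categorical_crossentropy",
              metrics=[keras_metrics.precision(label=1, sparse=True),
                       keras_metrics.recall(label=1, sparse=True)])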

keras_metrics/metrics.py

Lines changed: 26 additions & 22 deletions
@@ -6,11 +6,12 @@

 class layer(Layer):

-    def __init__(self, label=None, name=None, **kwargs):
+    def __init__(self, label=None, name=None, sparse=False, **kwargs):
         super(layer, self).__init__(name=name, **kwargs)
         self.stateful = True
         self.epsilon = K.constant(K.epsilon(), dtype="float64")
         self.__name__ = name
+        self.sparse = sparse

     # If layer metric is explicitly created to evaluate specified class,
     # then use a binary transformation of the output arrays, otherwise
@@ -30,7 +31,10 @@ def _binary(self, y_true, y_pred, dtype, label=0):
         column = slice(label, label+1)

         # Slice a column of the output array.
-        y_true = y_true[...,column]
+        if self.sparse:
+            y_true = K.cast(K.equal(y_true, label), dtype=dtype)
+        else:
+            y_true = y_true[...,column]
         y_pred = y_pred[...,column]
         return self._categorical(y_true, y_pred, dtype)

@@ -65,8 +69,8 @@ class true_positive(layer):
     positive class.
     """

-    def __init__(self, name="true_positive", **kwargs):
-        super(true_positive, self).__init__(name=name, **kwargs)
+    def __init__(self, name="true_positive", label=None, sparse=False, **kwargs):
+        super(true_positive, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
         self.tp = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -92,8 +96,8 @@ class true_negative(layer):
     negative class.
     """

-    def __init__(self, name="true_negative", **kwargs):
-        super(true_negative, self).__init__(name=name, **kwargs)
+    def __init__(self, name="true_negative", label=None, sparse=False, **kwargs):
+        super(true_negative, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
         self.tn = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -122,8 +126,8 @@ class false_negative(layer):
     negative class.
     """

-    def __init__(self, name="false_negative", **kwargs):
-        super(false_negative, self).__init__(name=name, **kwargs)
+    def __init__(self, name="false_negative", label=None, sparse=False, **kwargs):
+        super(false_negative, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
         self.fn = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -150,8 +154,8 @@ class false_positive(layer):
     positive class.
     """

-    def __init__(self, name="false_positive", **kwargs):
-        super(false_positive, self).__init__(name=name, **kwargs)
+    def __init__(self, name="false_positive", label=None, sparse=False, **kwargs):
+        super(false_positive, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
         self.fp = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -177,11 +181,11 @@ class recall(layer):
     Recall measures proportion of actual positives that was identified correctly.
     """

-    def __init__(self, name="recall", **kwargs):
-        super(recall, self).__init__(name=name, **kwargs)
+    def __init__(self, name="recall", label=None, sparse=False, **kwargs):
+        super(recall, self).__init__(name=name, label=label, sparse=sparse, **kwargs)

-        self.tp = true_positive()
-        self.fn = false_negative()
+        self.tp = true_positive(label=label, sparse=sparse)
+        self.fn = false_negative(label=label, sparse=sparse)

     def reset_states(self):
         """Reset the state of the metrics."""
@@ -208,11 +212,11 @@ class precision(layer):
     actually correct.
     """

-    def __init__(self, name="precision", **kwargs):
-        super(precision, self).__init__(name=name, **kwargs)
+    def __init__(self, name="precision", label=None, sparse=False, **kwargs):
+        super(precision, self).__init__(name=name, label=label, sparse=sparse, **kwargs)

-        self.tp = true_positive()
-        self.fp = false_positive()
+        self.tp = true_positive(label=label, sparse=sparse)
+        self.fp = false_positive(label=label, sparse=sparse)

     def reset_states(self):
         """Reset the state of the metrics."""
@@ -238,11 +242,11 @@ class f1_score(layer):
     The F1 score is the harmonic mean of precision and recall.
     """

-    def __init__(self, name="f1_score", label=None, **kwargs):
-        super(f1_score, self).__init__(name=name, label=label, **kwargs)
+    def __init__(self, name="f1_score", label=None, sparse=False, **kwargs):
+        super(f1_score, self).__init__(name=name, label=label, sparse=sparse, **kwargs)

-        self.precision = precision(label=label)
-        self.recall = recall(label=label)
+        self.precision = precision(label=label, sparse=sparse)
+        self.recall = recall(label=label, sparse=sparse)

     def reset_states(self):
         """Reset the state of the metrics."""

tests/test_metrics.py

Lines changed: 43 additions & 18 deletions
@@ -9,38 +9,52 @@

 class TestMetrics(unittest.TestCase):

+    def __init__(self, methodName, sparse=False):
+        super(TestMetrics, self).__init__(methodName=methodName)
+        self.sparse = sparse
+
     def setUp(self):
-        tp = keras_metrics.true_positive()
-        tn = keras_metrics.true_negative()
-        fp = keras_metrics.false_positive()
-        fn = keras_metrics.false_negative()
+        tp = keras_metrics.true_positive(sparse=self.sparse)
+        tn = keras_metrics.true_negative(sparse=self.sparse)
+        fp = keras_metrics.false_positive(sparse=self.sparse)
+        fn = keras_metrics.false_negative(sparse=self.sparse)

-        precision = keras_metrics.precision()
-        recall = keras_metrics.recall()
-        f1 = keras_metrics.f1_score()
+        precision = keras_metrics.precision(sparse=self.sparse)
+        recall = keras_metrics.recall(sparse=self.sparse)
+        f1 = keras_metrics.f1_score(sparse=self.sparse)

         self.model = keras.models.Sequential()
         self.model.add(keras.layers.Activation(keras.backend.sin))
         self.model.add(keras.layers.Activation(keras.backend.abs))

+        if self.sparse:
+            loss = "sparse_categorical_crossentropy"
+        else:
+            loss = "binary_crossentropy"
+
         self.model.compile(optimizer="sgd",
-                           loss="binary_crossentropy",
+                           loss=loss,
                            metrics=[tp, tn, fp, fn, precision, recall, f1])

     def samples(self, n):
-        x = numpy.random.uniform(0, numpy.pi/2, (n, 1))
-        y = numpy.random.randint(2, size=(n, 1))
+        if self.sparse:
+            categories = 2
+            x = numpy.random.uniform(0, numpy.pi/2, (n, categories))
+            y = numpy.random.randint(categories, size=(n, 1))
+        else:
+            x = numpy.random.uniform(0, numpy.pi/2, (n, 1))
+            y = numpy.random.randint(2, size=(n, 1))
         return x, y

     def test_save_load(self):
         custom_objects = {
-            "true_positive": keras_metrics.true_positive(),
-            "true_negative": keras_metrics.true_negative(),
-            "false_positive": keras_metrics.false_positive(),
-            "false_negative": keras_metrics.false_negative(),
-            "precision": keras_metrics.precision(),
-            "recall": keras_metrics.recall(),
-            "f1_score": keras_metrics.f1_score(),
+            "true_positive": keras_metrics.true_positive(sparse=self.sparse),
+            "true_negative": keras_metrics.true_negative(sparse=self.sparse),
+            "false_positive": keras_metrics.false_positive(sparse=self.sparse),
+            "false_negative": keras_metrics.false_negative(sparse=self.sparse),
+            "precision": keras_metrics.precision(sparse=self.sparse),
+            "recall": keras_metrics.recall(sparse=self.sparse),
+            "f1_score": keras_metrics.f1_score(sparse=self.sparse),
             "sin": keras.backend.sin,
             "abs": keras.backend.abs,
         }
@@ -97,5 +111,16 @@ def test_metrics(self):
         self.assertAlmostEqual(expected_f1, f1, places=places)


+def suite():
+    s = unittest.TestSuite()
+    s.addTests(TestMetrics(methodName=method, sparse=sparse)
+               for method in unittest.defaultTestLoader.getTestCaseNames(TestMetrics)
+               for sparse in (False, True))
+    return s
+
+
 if __name__ == "__main__":
-    unittest.main()
+    import sys
+    result = unittest.TextTestRunner().run(suite())
+    sys.exit(not result.wasSuccessful())
+    # unittest.main()
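The new suite() hook runs every TestMetrics method twice, once per value of the sparse flag. A short illustration of the expansion (hypothetical session, assuming it is executed inside tests/test_metrics.py):

import unittest

# defaultTestLoader collects the test_* method names of the class, e.g.
# ["test_metrics", "test_save_load"]; suite() then instantiates each of
# them once with sparse=False and once with sparse=True.
names = unittest.defaultTestLoader.getTestCaseNames(TestMetrics)
assert len(list(suite())) == 2 * len(names)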
