Commit d1cf697

Release version 1.1.0 (#38)
New release provides an average recall metric.
1 parent 02b890d · commit d1cf697

5 files changed (+64 -17 additions, -87 deletions)

keras_metrics/__init__.py (+4 -17)

@@ -3,7 +3,7 @@
 from keras_metrics import casts
 
 
-__version__ = "1.0.0"
+__version__ = "1.1.0"
 
 
 def metric_fn(cls, cast_strategy):
@@ -25,18 +25,6 @@ def fn(label=0, **kwargs):
 sparse_categorical_metric = partial(
     metric_fn, cast_strategy=casts.sparse_categorical)
 
-binary_average_metric = partial(
-    metric_fn, cast_strategy=casts.binary_argmax
-)
-
-categorical_average_metric = partial(
-    metric_fn, cast_strategy=casts.argmax
-)
-
-sparse_categorical_average_metric = partial(
-    metric_fn, cast_strategy=casts.sparse_argmax
-)
-
 
 binary_true_positive = binary_metric(m.true_positive)
 binary_true_negative = binary_metric(m.true_negative)
@@ -45,7 +33,7 @@ def fn(label=0, **kwargs):
 binary_precision = binary_metric(m.precision)
 binary_recall = binary_metric(m.recall)
 binary_f1_score = binary_metric(m.f1_score)
-binary_average_recall = binary_average_metric(m.average_recall)
+binary_average_recall = binary_metric(m.average_recall)
 
 
 categorical_true_positive = categorical_metric(m.true_positive)
@@ -55,7 +43,7 @@ def fn(label=0, **kwargs):
 categorical_precision = categorical_metric(m.precision)
 categorical_recall = categorical_metric(m.recall)
 categorical_f1_score = categorical_metric(m.f1_score)
-categorical_average_recall = categorical_average_metric(m.average_recall)
+categorical_average_recall = categorical_metric(m.average_recall)
 
 
 sparse_categorical_true_positive = sparse_categorical_metric(m.true_positive)
@@ -65,8 +53,7 @@ def fn(label=0, **kwargs):
 sparse_categorical_precision = sparse_categorical_metric(m.precision)
 sparse_categorical_recall = sparse_categorical_metric(m.recall)
 sparse_categorical_f1_score = sparse_categorical_metric(m.f1_score)
-sparse_categorical_average_recall = sparse_categorical_average_metric(
-    m.average_recall)
+sparse_categorical_average_recall = sparse_categorical_metric(m.average_recall)
 
 
 # For backward compatibility.
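
With the dedicated average-metric factories removed, the average-recall
metrics are now built through the same factories as every other metric.
A minimal usage sketch (the model below is illustrative; only the km call
mirrors tests/test_average_recall.py from this commit):

    import keras
    import keras_metrics as km

    # Hypothetical three-class classifier; only the metrics argument
    # is taken from this commit's test suite.
    model = keras.models.Sequential([
        keras.layers.Dense(3, activation="softmax", input_shape=(4,)),
    ])
    model.compile(optimizer="sgd",
                  loss="categorical_crossentropy",
                  metrics=[km.categorical_average_recall(labels=3)])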

keras_metrics/casts.py (-21)

@@ -24,24 +24,3 @@ def sparse_categorical(y_true, y_pred, dtype="int32", label=0):
     y_pred = K.cast(K.round(y_pred), dtype)
 
     return y_true, y_pred
-
-
-def binary_argmax(y_true, y_pred, dtype="int32", label=0):
-    y_true, y_pred = K.squeeze(y_true, axis=-1), K.squeeze(y_pred, axis=-1)
-    y_true, y_pred = K.cast(y_true, dtype=dtype), K.cast(y_pred, dtype=dtype)
-
-    return y_true, y_pred
-
-
-def argmax(y_true, y_pred, dtype="int32", label=0):
-    y_true, y_pred = K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)
-    y_true, y_pred = K.cast(y_true, dtype=dtype), K.cast(y_pred, dtype=dtype)
-
-    return y_true, y_pred
-
-
-def sparse_argmax(y_true, y_pred, dtype="int32", label=0):
-    y_true, y_pred = K.squeeze(y_true, axis=-1), K.argmax(y_pred, axis=-1)
-    y_true, y_pred = K.cast(y_true, dtype=dtype), K.cast(y_pred, dtype=dtype)
-
-    return y_true, y_pred
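
For context, the removed strategies collapsed one-hot tensors to integer
label indices before the metric computation; the reworked average_recall
(next file) operates on the categorical representation directly. A numpy
sketch of what the removed argmax cast computed, with made-up inputs:

    import numpy as np

    # Equivalent of the removed argmax cast: reduce one-hot rows to
    # integer label indices and cast to int32.
    y_true = np.array([[0, 1, 0], [1, 0, 0]])
    y_pred = np.array([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])

    t = y_true.argmax(axis=-1).astype("int32")  # [1, 0]
    p = y_pred.argmax(axis=-1).astype("int32")  # [1, 0]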

keras_metrics/metrics.py (+18 -36)

@@ -232,51 +232,33 @@ class average_recall(layer):
     """Create a metric for the average recall calculation.
     """
 
-    def __init__(self, name="average_recall", classes=2, **kwargs):
+    def __init__(self, name="average_recall", labels=1, **kwargs):
         super(average_recall, self).__init__(name=name, **kwargs)
 
-        if classes < 2:
-            raise ValueError('argument classes must >= 2')
+        self.labels = labels
 
-        self.classes = classes
-
-        self.true = K.zeros(classes, dtype="int32")
-        self.pred = K.zeros(classes, dtype="int32")
+        self.tp = K.zeros(labels, dtype="int32")
+        self.fn = K.zeros(labels, dtype="int32")
 
     def reset_states(self):
-        K.set_value(self.true, [0 for v in range(self.classes)])
-        K.set_value(self.pred, [0 for v in range(self.classes)])
+        K.set_value(self.tp, [0]*self.labels)
+        K.set_value(self.fn, [0]*self.labels)
 
     def __call__(self, y_true, y_pred):
-        # Cast input
-        t, p = self.cast(y_true, y_pred, dtype="float64")
-
-        # Init a bias matrix
-        b = K.variable([truediv(1, (v + 1)) for v in range(self.classes)],
-                       dtype="float64")
-
-        # Simulate to_categorical operation
-        t, p = K.expand_dims(t, axis=-1), K.expand_dims(p, axis=-1)
-        t, p = (t + 1) * b - 1, (p + 1) * b - 1
-
-        # Make correct position filled with 1
-        t, p = K.cast(t, "bool"), K.cast(p, "bool")
-        t, p = 1 - K.cast(t, "int32"), 1 - K.cast(p, "int32")
-
-        t, p = K.transpose(t), K.transpose(p)
+        y_true = K.cast(K.round(y_true), "int32")
+        y_pred = K.cast(K.round(y_pred), "int32")
+        neg_y_pred = 1 - y_pred
 
-        # Results for current batch
-        batch_t = K.sum(t, axis=-1)
-        batch_p = K.sum(t * p, axis=-1)
+        tp = K.sum(K.transpose(y_true * y_pred), axis=-1)
+        fn = K.sum(K.transpose(y_true * neg_y_pred), axis=-1)
 
-        # Accumulated results
-        total_t = self.true * 1 + batch_t
-        total_p = self.pred * 1 + batch_p
+        current_tp = K.cast(self.tp + tp, self.epsilon.dtype)
+        current_fn = K.cast(self.fn + fn, self.epsilon.dtype)
 
-        self.add_update(K.update_add(self.true, batch_t))
-        self.add_update(K.update_add(self.pred, batch_p))
+        tp_update = K.update_add(self.tp, tp)
+        fn_update = K.update_add(self.fn, fn)
 
-        tp = K.cast(total_p, dtype='float64')
-        tt = K.cast(total_t, dtype='float64')
+        self.add_update(tp_update, inputs=[y_true, y_pred])
+        self.add_update(fn_update, inputs=[y_true, y_pred])
 
-        return K.mean(truediv(tp, (tt + self.epsilon)))
+        return K.mean(truediv(current_tp, current_tp + current_fn + self.epsilon))
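
The reworked __call__ accumulates per-label true positives and false
negatives across batches and returns the mean per-label recall, i.e. a
macro-averaged recall. A numpy sketch of a single batch, with made-up
values:

    import numpy as np

    eps = 1e-7  # stands in for self.epsilon

    # One-hot truth and rounded predictions for 4 samples, 3 labels.
    y_true = np.array([[1, 0, 0],
                       [0, 1, 0],
                       [0, 1, 0],
                       [0, 0, 1]])
    y_pred = np.array([[1, 0, 0],
                       [1, 0, 0],
                       [0, 1, 0],
                       [0, 0, 1]])

    tp = (y_true * y_pred).sum(axis=0)        # per-label TP: [1, 1, 1]
    fn = (y_true * (1 - y_pred)).sum(axis=0)  # per-label FN: [0, 1, 0]

    recall = tp / (tp + fn + eps)             # [1.0, 0.5, 1.0]
    print(recall.mean())                      # ~0.8333

Adding self.epsilon to the denominator keeps the division defined for
labels that never occur in the accumulated counts.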

tests/test_average_recall.py (+42, new file)

@@ -0,0 +1,42 @@
+import keras
+import keras.utils
+import keras_metrics as km
+import numpy
+import unittest
+
+
+class TestAverageRecall(unittest.TestCase):
+
+    def create_samples(self, n, labels=1):
+        x = numpy.random.uniform(0, numpy.pi/2, (n, labels))
+        y = numpy.random.randint(labels, size=(n, 1))
+        return x, keras.utils.to_categorical(y)
+
+    def test_average_recall(self):
+        model = keras.models.Sequential()
+        model.add(keras.layers.Activation(keras.backend.sin))
+        model.add(keras.layers.Activation(keras.backend.abs))
+        model.add(keras.layers.Softmax())
+        model.compile(optimizer="sgd",
+                      loss="categorical_crossentropy",
+                      metrics=[
+                          km.categorical_recall(label=0),
+                          km.categorical_recall(label=1),
+                          km.categorical_recall(label=2),
+                          km.categorical_average_recall(labels=3),
+                      ])
+
+        x, y = self.create_samples(10000, labels=3)
+
+        model.fit(x, y, epochs=10, batch_size=100)
+        metrics = model.evaluate(x, y, batch_size=100)[1:]
+
+        r0, r1, r2 = metrics[0:3]
+        average_recall = metrics[3]
+
+        expected_recall = (r0+r1+r2)/3.0
+        self.assertAlmostEqual(expected_recall, average_recall, places=3)
+
+
+if __name__ == "__main__":
+    unittest.main()
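
The test cross-checks the new metric against the unweighted mean of the
three per-label recalls. Since the module ends with unittest.main(), it
can be run on its own (invocation assumed from the repository layout):

    python tests/test_average_recall.py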

tests/test_metrics.py (-13)

@@ -1,7 +1,6 @@
 import keras
 import keras.backend
 import keras.utils
-import keras.regularizers
 import keras_metrics as km
 import itertools
 import numpy
@@ -21,7 +20,6 @@ class TestMetrics(unittest.TestCase):
         km.binary_precision,
         km.binary_recall,
         km.binary_f1_score,
-        km.binary_average_recall
     ]
 
     categorical_metrics = [
@@ -32,7 +30,6 @@ class TestMetrics(unittest.TestCase):
         km.categorical_precision,
         km.categorical_recall,
         km.categorical_f1_score,
-        km.categorical_average_recall,
     ]
 
     sparse_categorical_metrics = [
@@ -43,7 +40,6 @@ class TestMetrics(unittest.TestCase):
         km.sparse_categorical_precision,
         km.sparse_categorical_recall,
         km.sparse_categorical_f1_score,
-        km.sparse_categorical_average_recall,
     ]
 
     def create_binary_samples(self, n):
@@ -63,9 +59,6 @@ def create_model(self, outputs, loss, metrics_fns):
         model.add(keras.layers.Activation(keras.backend.sin))
         model.add(keras.layers.Activation(keras.backend.abs))
         model.add(keras.layers.Lambda(lambda x: K.concatenate([x]*outputs)))
-        scale = [v + 1 for v in range(outputs)]
-        model.add(keras.layers.Lambda(lambda x: (0.5 - x) * scale + 1))
-        model.add(keras.layers.Softmax())
         model.compile(optimizer="sgd",
                       loss=loss,
                       metrics=self.create_metrics(metrics_fns))
@@ -132,14 +125,10 @@ def assert_metrics(self, model, samples_fn):
         precision = metrics[4]
         recall = metrics[5]
         f1 = metrics[6]
-        average_recall = metrics[7]
 
         expected_precision = tp_val / (tp_val + fp_val)
         expected_recall = tp_val / (tp_val + fn_val)
 
-        expected_average_recall = (
-            expected_recall + (tn_val / (fp_val + tn_val))) / 2
-
         f1_divident = (expected_precision*expected_recall)
         f1_divisor = (expected_precision+expected_recall)
         expected_f1 = (2 * f1_divident / f1_divisor)
@@ -155,8 +144,6 @@ def assert_metrics(self, model, samples_fn):
         self.assertAlmostEqual(expected_precision, precision, places=places)
         self.assertAlmostEqual(expected_recall, recall, places=places)
         self.assertAlmostEqual(expected_f1, f1, places=places)
-        self.assertAlmostEqual(expected_average_recall,
-                               average_recall, places=places)
 
     def test_binary_metrics(self):
         model = self.create_model(1, "binary_crossentropy",
