Commit 2469a27

Independent metrics (#31)
This patch separates metrics by the task they are meant to measure. There are now three types of metrics:

- binary metrics
- categorical metrics
- sparse categorical metrics

To preserve backward compatibility with older code, the old names "precision", "recall", etc. refer to "binary_precision", "binary_recall", etc. respectively. These changes require bumping the major version of the package.
1 parent 90e5e4f commit 2469a27
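
In short, every metric now has a binary, a categorical, and a sparse categorical variant, and the unprefixed names remain as aliases of the binary ones. A minimal sketch of the resulting API (assuming the package is imported as `km`):

```py
import keras_metrics as km

# Task-specific constructors introduced by this patch.
binary_p = km.binary_precision()
categorical_p = km.categorical_precision()
sparse_categorical_p = km.sparse_categorical_precision()

# The old name still works and resolves to the binary variant.
legacy_p = km.precision()
```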

5 files changed: +267 −139 lines


README.md

Lines changed: 43 additions & 5 deletions
@@ -18,35 +18,73 @@ pip install keras-metrics
 The usage of the package is simple:
 ```py
 import keras
-import keras_metrics
+import keras_metrics as km

 model = keras.models.Sequential()
 model.add(keras.layers.Dense(1, activation="sigmoid", input_dim=2))
 model.add(keras.layers.Dense(1, activation="softmax"))

 model.compile(optimizer="sgd",
               loss="binary_crossentropy",
-              metrics=[keras_metrics.precision(), keras_metrics.recall()])
+              metrics=[km.binary_precision(), km.binary_recall()])
 ```

 Similar configuration for multi-label binary crossentropy:
 ```py
 import keras
-import keras_metrics
+import keras_metrics as km

 model = keras.models.Sequential()
 model.add(keras.layers.Dense(1, activation="sigmoid", input_dim=2))
 model.add(keras.layers.Dense(2, activation="softmax"))

 # Calculate precision for the second label.
-precision = keras_metrics.precision(label=1)
+precision = km.binary_precision(label=1)

 # Calculate recall for the first label.
-recall = keras_metrics.recall(label=0)
+recall = km.recall(label=0)

 model.compile(optimizer="sgd",
               loss="binary_crossentropy",
               metrics=[precision, recall])
 ```

+The keras-metrics package also supports metrics for categorical crossentropy
+and sparse categorical crossentropy:
+```py
+import keras_metrics as km
+
+c_precision = km.categorical_precision()
+sc_precision = km.sparse_categorical_precision()
+
+# ...
+```
+
+## Tensorflow Keras
+
+The Tensorflow library provides the ```keras``` package as part of its API.
+In order to use ```keras_metrics``` with Tensorflow Keras, you are advised to
+perform model training with initialized global variables:
+```py
+import numpy as np
+import keras_metrics as km
+import tensorflow as tf
+import tensorflow.keras as keras
+
+model = keras.Sequential()
+model.add(keras.layers.Dense(1, activation="softmax"))
+model.compile(optimizer="sgd",
+              loss="binary_crossentropy",
+              metrics=[km.binary_true_positive()])
+
+x = np.array([[0], [1], [0], [1]])
+y = np.array([1, 0, 1, 0])
+
+# Wrap model.fit into the session with global
+# variables initialization.
+with tf.Session() as s:
+    s.run(tf.global_variables_initializer())
+    model.fit(x=x, y=y)
+```
+
 [BuildStatus]: https://travis-ci.org/netrack/keras-metrics.svg?branch=master

keras_metrics/__init__.py

Lines changed: 60 additions & 1 deletion
@@ -1,4 +1,63 @@
+from functools import partial
+from keras_metrics import metrics as m
+from keras_metrics import casts
+
+
 __version__ = "0.0.7"


-from keras_metrics.metrics import *
+def metric_fn(cls, cast_strategy):
+    def fn(label=0, **kwargs):
+        metric = cls(label=label, cast_strategy=cast_strategy, **kwargs)
+        metric.__name__ = "%s_%s" % (cast_strategy.__name__, cls.__name__)
+        return metric
+    return fn
+
+
+binary_metric = partial(
+    metric_fn, cast_strategy=casts.binary)
+
+
+categorical_metric = partial(
+    metric_fn, cast_strategy=casts.categorical)
+
+
+sparse_categorical_metric = partial(
+    metric_fn, cast_strategy=casts.sparse_categorical)
+
+
+binary_true_positive = binary_metric(m.true_positive)
+binary_true_negative = binary_metric(m.true_negative)
+binary_false_positive = binary_metric(m.false_positive)
+binary_false_negative = binary_metric(m.false_negative)
+binary_precision = binary_metric(m.precision)
+binary_recall = binary_metric(m.recall)
+binary_f1_score = binary_metric(m.f1_score)
+
+
+categorical_true_positive = categorical_metric(m.true_positive)
+categorical_true_negative = categorical_metric(m.true_negative)
+categorical_false_positive = categorical_metric(m.false_positive)
+categorical_false_negative = categorical_metric(m.false_negative)
+categorical_precision = categorical_metric(m.precision)
+categorical_recall = categorical_metric(m.recall)
+categorical_f1_score = categorical_metric(m.f1_score)
+
+
+sparse_categorical_true_positive = sparse_categorical_metric(m.true_positive)
+sparse_categorical_true_negative = sparse_categorical_metric(m.true_negative)
+sparse_categorical_false_positive = sparse_categorical_metric(m.false_positive)
+sparse_categorical_false_negative = sparse_categorical_metric(m.false_negative)
+sparse_categorical_precision = sparse_categorical_metric(m.precision)
+sparse_categorical_recall = sparse_categorical_metric(m.recall)
+sparse_categorical_f1_score = sparse_categorical_metric(m.f1_score)
+
+
+# For backward compatibility.
+true_positive = binary_true_positive
+true_negative = binary_true_negative
+false_positive = binary_false_positive
+false_negative = binary_false_negative
+precision = binary_precision
+recall = binary_recall
+f1_score = binary_f1_score
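
The factory above is small enough to illustrate in isolation. Below is a minimal, self-contained sketch of the same pattern; the `precision` class and `binary` function are hypothetical stand-ins for `keras_metrics.metrics.precision` and `keras_metrics.casts.binary`, used only to make the sketch runnable:

```py
from functools import partial


def metric_fn(cls, cast_strategy):
    # Same shape as the factory in keras_metrics/__init__.py: bind a metric
    # class to a cast strategy and name the result "<cast>_<metric>".
    def fn(label=0, **kwargs):
        metric = cls(label=label, cast_strategy=cast_strategy, **kwargs)
        metric.__name__ = "%s_%s" % (cast_strategy.__name__, cls.__name__)
        return metric
    return fn


# Hypothetical stand-ins for keras_metrics.metrics.precision and
# keras_metrics.casts.binary; they exist only to make the sketch runnable.
class precision:
    def __init__(self, label=0, cast_strategy=None):
        self.label = label
        self.cast_strategy = cast_strategy


def binary(y_true, y_pred, dtype="int32", label=0):
    return y_true, y_pred


binary_metric = partial(metric_fn, cast_strategy=binary)
binary_precision = binary_metric(precision)

metric = binary_precision(label=1)
print(metric.__name__)  # -> "binary_precision"
```

Binding the cast strategy with `functools.partial` is what lets a single metric implementation serve all three task types.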

keras_metrics/casts.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from keras import backend as K
+
+
+def binary(y_true, y_pred, dtype="int32", label=0):
+    return categorical(y_true, y_pred, dtype, label)
+
+
+def categorical(y_true, y_pred, dtype="int32", label=0):
+    column = slice(label, label+1)
+
+    y_true = y_true[..., column]
+    y_pred = y_pred[..., column]
+
+    y_true = K.cast(K.round(y_true), dtype)
+    y_pred = K.cast(K.round(y_pred), dtype)
+
+    return y_true, y_pred
+
+
+def sparse_categorical(y_true, y_pred, dtype="int32", label=0):
+    y_true = K.cast(K.equal(y_true, label), dtype=dtype)
+
+    y_pred = y_pred[..., slice(label, label+1)]
+    y_pred = K.cast(K.round(y_pred), dtype)
+
+    return y_true, y_pred
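
To make the three strategies concrete, here is a hedged NumPy analogue of the slicing and rounding they perform (NumPy stands in for the Keras backend, and the array shapes are illustrative assumptions, not part of the package):

```py
import numpy as np

# Illustrative shapes only: two samples, three classes.
y_pred = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])
label = 1
column = slice(label, label + 1)

# categorical (and binary): y_true is one-hot; slice the label's column
# out of both tensors and round it to integers.
y_true_onehot = np.array([[0, 1, 0],
                          [1, 0, 0]])
t = np.round(y_true_onehot[..., column]).astype("int32")  # [[1], [0]]
p = np.round(y_pred[..., column]).astype("int32")         # [[1], [0]]

# sparse_categorical: y_true holds integer class indices; compare them
# against the requested label and slice only y_pred.
y_true_sparse = np.array([[1], [0]])
t_sparse = (y_true_sparse == label).astype("int32")       # [[1], [0]]
p_sparse = np.round(y_pred[..., column]).astype("int32")  # [[1], [0]]
```

In all three cases the result is a pair of integer column vectors for the requested label, which is what the counting metrics in `metrics.py` consume.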

keras_metrics/metrics.py

Lines changed: 29 additions & 64 deletions
@@ -1,76 +1,40 @@
-from functools import partial
 from keras import backend as K
 from keras.layers import Layer
 from operator import truediv


 class layer(Layer):

-    def __init__(self, label=None, name=None, sparse=False, **kwargs):
+    def __init__(self, name=None, label=0, cast_strategy=None, **kwargs):
         super(layer, self).__init__(name=name, **kwargs)
+
         self.stateful = True
+        self.label = label
+        self.cast_strategy = cast_strategy
         self.epsilon = K.constant(K.epsilon(), dtype="float64")
-        self.__name__ = name
-        self.sparse = sparse
-
-        # If layer metric is explicitly created to evaluate specified class,
-        # then use a binary transformation of the output arrays, otherwise
-        # calculate an "overall" metric.
-        if label:
-            self.cast_strategy = partial(self._binary, label=label)
-        else:
-            self.cast_strategy = self._categorical

     def cast(self, y_true, y_pred, dtype="int32"):
         """Convert the specified true and predicted output to the specified
         destination type (int32 by default).
         """
-        return self.cast_strategy(y_true, y_pred, dtype=dtype)
-
-    def _binary(self, y_true, y_pred, dtype, label=0):
-        column = slice(label, label+1)
-
-        # Slice a column of the output array.
-        if self.sparse:
-            y_true = K.cast(K.equal(y_true, label), dtype=dtype)
-        else:
-            y_true = y_true[...,column]
-        y_pred = y_pred[...,column]
-        return self._categorical(y_true, y_pred, dtype)
-
-    def _categorical(self, y_true, y_pred, dtype):
-        # In case when user did not specify the label, and the shape
-        # of the output vector has exactly two elements, we can choose
-        # the label automatically.
-        #
-        # When the shape had dimension 3 and more and the label is
-        # not specified, we should throw an error as long as calculated
-        # metric is incorrect.
-        labels = y_pred.shape[-1]
-        if labels == 2:
-            return self._binary(y_true, y_pred, dtype, label=1)
-        elif labels > 2:
-            raise ValueError("With 2 and more output classes a "
-                             "metric label must be specified")
-
-        y_true = K.cast(K.round(y_true), dtype)
-        y_pred = K.cast(K.round(y_pred), dtype)
-        return y_true, y_pred
+        return self.cast_strategy(
+            y_true, y_pred, dtype=dtype, label=self.label)

     def __getattribute__(self, name):
         if name == "get_config":
             raise AttributeError
         return object.__getattribute__(self, name)

+
 class true_positive(layer):
     """Create a metric for model's true positives amount calculation.

     A true positive is an outcome where the model correctly predicts the
     positive class.
     """

-    def __init__(self, name="true_positive", label=None, sparse=False, **kwargs):
-        super(true_positive, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="true_positive", **kwargs):
+        super(true_positive, self).__init__(name=name, **kwargs)
         self.tp = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -96,8 +60,8 @@ class true_negative(layer):
     negative class.
     """

-    def __init__(self, name="true_negative", label=None, sparse=False, **kwargs):
-        super(true_negative, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="true_negative", **kwargs):
+        super(true_negative, self).__init__(name=name, **kwargs)
         self.tn = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -126,8 +90,8 @@ class false_negative(layer):
     negative class.
     """

-    def __init__(self, name="false_negative", label=None, sparse=False, **kwargs):
-        super(false_negative, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="false_negative", **kwargs):
+        super(false_negative, self).__init__(name=name, **kwargs)
         self.fn = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -154,8 +118,8 @@ class false_positive(layer):
     positive class.
     """

-    def __init__(self, name="false_positive", label=None, sparse=False, **kwargs):
-        super(false_positive, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="false_positive", **kwargs):
+        super(false_positive, self).__init__(name=name, **kwargs)
         self.fp = K.variable(0, dtype="int32")

     def reset_states(self):
@@ -178,14 +142,15 @@ def __call__(self, y_true, y_pred):
 class recall(layer):
     """Create a metric for model's recall calculation.

-    Recall measures proportion of actual positives that was identified correctly.
+    Recall measures proportion of actual positives that was identified
+    correctly.
     """

-    def __init__(self, name="recall", label=None, sparse=False, **kwargs):
-        super(recall, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="recall", **kwargs):
+        super(recall, self).__init__(name=name, **kwargs)

-        self.tp = true_positive(label=label, sparse=sparse)
-        self.fn = false_negative(label=label, sparse=sparse)
+        self.tp = true_positive(**kwargs)
+        self.fn = false_negative(**kwargs)

     def reset_states(self):
         """Reset the state of the metrics."""
@@ -212,11 +177,11 @@ class precision(layer):
     actually correct.
     """

-    def __init__(self, name="precision", label=None, sparse=False, **kwargs):
-        super(precision, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="precision", **kwargs):
+        super(precision, self).__init__(name=name, **kwargs)

-        self.tp = true_positive(label=label, sparse=sparse)
-        self.fp = false_positive(label=label, sparse=sparse)
+        self.tp = true_positive(**kwargs)
+        self.fp = false_positive(**kwargs)

     def reset_states(self):
         """Reset the state of the metrics."""
@@ -242,11 +207,11 @@ class f1_score(layer):
     The F1 score is the harmonic mean of precision and recall.
     """

-    def __init__(self, name="f1_score", label=None, sparse=False, **kwargs):
-        super(f1_score, self).__init__(name=name, label=label, sparse=sparse, **kwargs)
+    def __init__(self, name="f1_score", **kwargs):
+        super(f1_score, self).__init__(name=name, **kwargs)

-        self.precision = precision(label=label, sparse=sparse)
-        self.recall = recall(label=label, sparse=sparse)
+        self.precision = precision(**kwargs)
+        self.recall = recall(**kwargs)

     def reset_states(self):
         """Reset the state of the metrics."""