Skip to content

Commit 13ba6c7

Browse files
authored
Multi-label F1 score (#11)
This patch adds support of multi-label F1 score as well as unit test to ensure that F1 is calculated correctly.
1 parent 4d1281a commit 13ba6c7

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

keras_metrics/metrics.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -218,21 +218,21 @@ def __call__(self, y_true, y_pred):
218218

219219

220220
class f1_score(layer):
221-
"""Create a metric for the model's F1 score calculation.
221+
"""Create a metric for the model's F1 score calculation.
222222
223223
The F1 score is the harmonic mean of precision and recall.
224224
"""
225225

226-
def __init__(self, name="f1_score", **kwargs):
227-
super(f1_score, self).__init__(name=name, **kwargs)
226+
def __init__(self, name="f1_score", label=None, **kwargs):
227+
super(f1_score, self).__init__(name=name, label=label, **kwargs)
228228

229-
self.pr = precision()
230-
self.rec = recall()
229+
self.precision = precision(label=label)
230+
self.recall = recall(label=label)
231231

232232
def reset_states(self):
233233
"""Reset the state of the metrics."""
234-
self.pr.reset_states()
235-
self.rec.reset_states()
234+
self.precision.reset_states()
235+
self.recall.reset_states()
236236

237237
def __call__(self, y_true, y_pred):
238238
pr = self.precision(y_true, y_pred)

tests/test_metrics.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ def test_metrics(self):
1313

1414
precision = keras_metrics.precision()
1515
recall = keras_metrics.recall()
16+
f1 = keras_metrics.f1_score()
1617

1718
model = keras.models.Sequential()
1819
model.add(keras.layers.Dense(1, activation="sigmoid", input_dim=2))
1920
model.add(keras.layers.Dense(1, activation="softmax"))
2021

2122
model.compile(optimizer="sgd",
2223
loss="binary_crossentropy",
23-
metrics=[tp, fp, fn, precision, recall])
24+
metrics=[tp, fp, fn, precision, recall, f1])
2425

2526
samples = 1000
2627
x = numpy.random.random((samples, 2))
@@ -35,12 +36,18 @@ def test_metrics(self):
3536

3637
precision = metrics[3]
3738
recall = metrics[4]
39+
f1 = metrics[5]
3840

3941
expected_precision = tp_val / (tp_val + fp_val)
4042
expected_recall = tp_val / (tp_val + fn_val)
4143

44+
f1_divident = (expected_precision*expected_recall)
45+
f1_divisor = (expected_precision+expected_recall)
46+
expected_f1 = (2 * f1_divident / f1_divisor)
47+
4248
self.assertAlmostEqual(expected_precision, precision, delta=0.05)
4349
self.assertAlmostEqual(expected_recall, recall, delta=0.05)
50+
self.assertAlmostEqual(expected_f1, f1, delta=0.05)
4451

4552

4653
if __name__ == "__main__":

0 commit comments

Comments
 (0)