Skip to content

Commit 94c4c93

Browse files
authored
Separate test for model loading (#28)
This patch defines a separate unit test for model loading.
1 parent 0bce37d commit 94c4c93

File tree

1 file changed

+42
-30
lines changed

1 file changed

+42
-30
lines changed

tests/test_metrics.py

+42-30
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import keras_metrics
44
import itertools
55
import numpy
6+
import tempfile
67
import unittest
7-
from keras.models import load_model
88

99

1010
class TestMetrics(unittest.TestCase):
1111

12-
def test_metrics(self):
12+
def setUp(self):
1313
tp = keras_metrics.true_positive()
1414
tn = keras_metrics.true_negative()
1515
fp = keras_metrics.false_positive()
@@ -19,23 +19,52 @@ def test_metrics(self):
1919
recall = keras_metrics.recall()
2020
f1 = keras_metrics.f1_score()
2121

22-
model = keras.models.Sequential()
23-
model.add(keras.layers.Activation(keras.backend.sin))
24-
model.add(keras.layers.Activation(keras.backend.abs))
22+
self.model = keras.models.Sequential()
23+
self.model.add(keras.layers.Activation(keras.backend.sin))
24+
self.model.add(keras.layers.Activation(keras.backend.abs))
25+
26+
self.model.compile(optimizer="sgd",
27+
loss="binary_crossentropy",
28+
metrics=[tp, tn, fp, fn, precision, recall, f1])
29+
30+
def samples(self, n):
31+
x = numpy.random.uniform(0, numpy.pi/2, (n, 1))
32+
y = numpy.random.randint(2, size=(n, 1))
33+
return x, y
34+
35+
def test_save_load(self):
36+
custom_objects = {
37+
"true_positive": keras_metrics.true_positive(),
38+
"true_negative": keras_metrics.true_negative(),
39+
"false_positive": keras_metrics.false_positive(),
40+
"false_negative": keras_metrics.false_negative(),
41+
"precision": keras_metrics.precision(),
42+
"recall": keras_metrics.recall(),
43+
"f1_score": keras_metrics.f1_score(),
44+
"sin": keras.backend.sin,
45+
"abs": keras.backend.abs,
46+
}
47+
48+
x, y = self.samples(100)
49+
self.model.fit(x, y, epochs=10)
50+
51+
with tempfile.NamedTemporaryFile() as file:
52+
self.model.save(file.name, overwrite=True)
53+
model = keras.models.load_model(file.name, custom_objects=custom_objects)
2554

26-
model.compile(optimizer="sgd",
27-
loss="binary_crossentropy",
28-
metrics=[tp, tn, fp, fn, precision, recall, f1])
55+
expected = self.model.evaluate(x, y)[1:]
56+
received = model.evaluate(x, y)[1:]
2957

58+
self.assertEqual(expected, received)
59+
60+
def test_metrics(self):
3061
samples = 10000
3162
batch_size = 100
32-
lim = numpy.pi/2
3363

34-
x = numpy.random.uniform(0, lim, (samples, 1))
35-
y = numpy.random.randint(2, size=(samples, 1))
64+
x, y = self.samples(samples)
3665

37-
model.fit(x, y, epochs=10, batch_size=batch_size)
38-
metrics = model.evaluate(x, y, batch_size=batch_size)[1:]
66+
self.model.fit(x, y, epochs=10, batch_size=batch_size)
67+
metrics = self.model.evaluate(x, y, batch_size=batch_size)[1:]
3968

4069
metrics = list(map(float, metrics))
4170

@@ -67,23 +96,6 @@ def test_metrics(self):
6796
self.assertAlmostEqual(expected_recall, recall, places=places)
6897
self.assertAlmostEqual(expected_f1, f1, places=places)
6998

70-
model.save('test.hdf5', overwrite=True)
71-
72-
del model
73-
74-
custom_objects = {
75-
"true_positive": keras_metrics.true_positive(),
76-
"true_negative": keras_metrics.true_negative(),
77-
"false_positive": keras_metrics.false_negative(),
78-
"false_negative": keras_metrics.false_negative(),
79-
"precision": keras_metrics.precision(),
80-
"recall": keras_metrics.recall(),
81-
"f1_score": keras_metrics.f1_score(),
82-
"sin": keras.backend.sin,
83-
"abs": keras.backend.abs,
84-
}
85-
86-
model = load_model('test.hdf5', custom_objects=custom_objects)
8799

88100
if __name__ == "__main__":
89101
unittest.main()

0 commit comments

Comments
 (0)