Skip to content

Commit 1e35db9

Browse files
fix metrics testing failure due to optimizer change (#2784)
1 parent 2ac80d0 commit 1e35db9

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tensorflow_addons/metrics/tests/streaming_correlations_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,12 @@ def test_keras_binary_classification_model(self, correlation_type):
100100
inputs = tf.keras.layers.Input(shape=(128,))
101101
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs)
102102
model = tf.keras.models.Model(inputs, outputs)
103+
if hasattr(tf.keras.optimizers, "legacy"):
104+
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.1)
105+
else:
106+
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
103107
model.compile(
104-
optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
108+
optimizer=optimizer,
105109
loss="binary_crossentropy",
106110
metrics=[metric],
107111
)
@@ -128,7 +132,7 @@ def test_keras_binary_classification_model(self, correlation_type):
128132
tf.function(metric.update_state)(y, preds)
129133
metric_value = tf.function(metric.result)()
130134
scipy_value = self.scipy_corr[correlation_type](preds[:, 0], y[:, 0])[0]
131-
np.testing.assert_almost_equal(metric_value, metric_history[-1])
135+
np.testing.assert_almost_equal(metric_value, metric_history[-1], decimal=5)
132136
np.testing.assert_almost_equal(metric_value, scipy_value, decimal=2)
133137

134138
@pytest.mark.parametrize("correlation_type", testing_types)

0 commit comments

Comments
 (0)