Skip to content

Commit 6ec0f46

Browse files
committed
Fix KerasFileEditor tests
1 parent 726a38f commit 6ec0f46

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

Diff for: keras/src/saving/file_editor_test.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -37,55 +37,55 @@ def test_basics(self):
3737

3838
target_model = get_target_model()
3939

40-
out = editor.compare_to(model) # Succeeds
40+
out = editor.compare(model) # Succeeds
4141
self.assertEqual(out["status"], "success")
42-
out = editor.compare_to(target_model) # Fails
42+
out = editor.compare(target_model) # Fails
4343

4444
editor.add_object(
4545
"layers/dense_3", weights={"0": np.random.random((3, 3))}
4646
)
47-
out = editor.compare_to(target_model) # Fails
47+
out = editor.compare(target_model) # Fails
4848
self.assertEqual(out["status"], "error")
4949
self.assertEqual(out["error_count"], 2)
5050

5151
editor.rename_object("dense_3", "dense_4")
5252
editor.rename_object("layers/dense_4", "dense_2")
5353
editor.add_weights("dense_2", weights={"1": np.random.random((3,))})
54-
out = editor.compare_to(target_model) # Succeeds
54+
out = editor.compare(target_model) # Succeeds
5555
self.assertEqual(out["status"], "success")
5656

5757
editor.add_object(
5858
"layers/dense_3", weights={"0": np.random.random((3, 3))}
5959
)
60-
out = editor.compare_to(target_model) # Fails
60+
out = editor.compare(target_model) # Fails
6161
self.assertEqual(out["status"], "error")
6262
self.assertEqual(out["error_count"], 1)
6363

6464
editor.delete_object("layers/dense_3")
65-
out = editor.compare_to(target_model) # Succeeds
65+
out = editor.compare(target_model) # Succeeds
6666
self.assertEqual(out["status"], "success")
6767
editor.summary()
6868

6969
temp_filepath = os.path.join(self.get_temp_dir(), "resaved.weights.h5")
70-
editor.resave_weights(temp_filepath)
70+
editor.save(temp_filepath)
7171
target_model.load_weights(temp_filepath)
7272

7373
editor = KerasFileEditor(temp_filepath)
7474
editor.summary()
75-
out = editor.compare_to(target_model) # Succeeds
75+
out = editor.compare(target_model) # Succeeds
7676
self.assertEqual(out["status"], "success")
7777

7878
editor.delete_weight("dense_2", "1")
79-
out = editor.compare_to(target_model) # Fails
79+
out = editor.compare(target_model) # Fails
8080
self.assertEqual(out["status"], "error")
8181
self.assertEqual(out["error_count"], 1)
8282

8383
editor.add_weights("dense_2", {"1": np.zeros((7,))})
84-
out = editor.compare_to(target_model) # Fails
84+
out = editor.compare(target_model) # Fails
8585
self.assertEqual(out["status"], "error")
8686
self.assertEqual(out["error_count"], 1)
8787

8888
editor.delete_weight("dense_2", "1")
8989
editor.add_weights("dense_2", {"1": np.zeros((3,))})
90-
out = editor.compare_to(target_model) # Succeeds
90+
out = editor.compare(target_model) # Succeeds
9191
self.assertEqual(out["status"], "success")

0 commit comments

Comments
 (0)