Skip to content

Commit f5e90a2

Browse files
Fix save_model and load_model (#19924)
1 parent 5f52db2 commit f5e90a2

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

keras/src/saving/saving_lib.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def _save_model_to_fileobj(model, fileobj, weights_format):
160160
f.write(config_json.encode())
161161

162162
weights_file_path = None
163+
weights_store = None
164+
asset_store = None
165+
write_zf = False
163166
try:
164167
if weights_format == "h5":
165168
if isinstance(fileobj, io.BufferedWriter):
@@ -168,6 +171,7 @@ def _save_model_to_fileobj(model, fileobj, weights_format):
168171
working_dir = pathlib.Path(fileobj.name).parent
169172
weights_file_path = working_dir / _VARS_FNAME_H5
170173
weights_store = H5IOStore(weights_file_path, mode="w")
174+
write_zf = True
171175
else:
172176
# Fall back when `fileobj` is an `io.BytesIO`. Typically,
173177
# this usage is for pickling.
@@ -196,13 +200,17 @@ def _save_model_to_fileobj(model, fileobj, weights_format):
196200
)
197201
except:
198202
# Skip the final `zf.write` if any exception is raised
199-
weights_file_path = None
203+
write_zf = False
200204
raise
201205
finally:
202-
weights_store.close()
203-
asset_store.close()
204-
if weights_file_path:
206+
if weights_store:
207+
weights_store.close()
208+
if asset_store:
209+
asset_store.close()
210+
if write_zf and weights_file_path:
205211
zf.write(weights_file_path, weights_file_path.name)
212+
if weights_file_path:
213+
weights_file_path.unlink()
206214

207215

208216
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
@@ -309,15 +317,22 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
309317

310318
all_filenames = zf.namelist()
311319
weights_file_path = None
320+
weights_store = None
321+
asset_store = None
312322
try:
313323
if _VARS_FNAME_H5 in all_filenames:
314324
if isinstance(fileobj, io.BufferedReader):
315325
# First, extract the model.weights.h5 file, then load it
316326
# using h5py.
317327
working_dir = pathlib.Path(fileobj.name).parent
318-
zf.extract(_VARS_FNAME_H5, working_dir)
319-
weights_file_path = working_dir / _VARS_FNAME_H5
320-
weights_store = H5IOStore(weights_file_path, mode="r")
328+
try:
329+
zf.extract(_VARS_FNAME_H5, working_dir)
330+
weights_file_path = working_dir / _VARS_FNAME_H5
331+
weights_store = H5IOStore(weights_file_path, mode="r")
332+
except OSError:
333+
# Fall back when it is a read-only system
334+
weights_file_path = None
335+
weights_store = H5IOStore(_VARS_FNAME_H5, zf, mode="r")
321336
else:
322337
# Fall back when `fileobj` is an `io.BytesIO`. Typically,
323338
# this usage is for pickling.
@@ -331,8 +346,6 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
331346

332347
if len(all_filenames) > 3:
333348
asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r")
334-
else:
335-
asset_store = None
336349

337350
failed_saveables = set()
338351
error_msgs = {}
@@ -346,7 +359,8 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
346359
error_msgs=error_msgs,
347360
)
348361
finally:
349-
weights_store.close()
362+
if weights_store:
363+
weights_store.close()
350364
if asset_store:
351365
asset_store.close()
352366
if weights_file_path:

keras/src/saving/saving_lib_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ def save_own_variables(self, store):
634634
with zipfile.ZipFile(filepath) as zf:
635635
all_filenames = zf.namelist()
636636
self.assertNotIn("model.weights.h5", all_filenames)
637+
self.assertFalse(Path(filepath).with_name("model.weights.h5").exists())
637638

638639
def test_load_model_exception_raised(self):
639640
# Assume we have an error in `load_own_variables`.

0 commit comments

Comments
 (0)