@@ -160,6 +160,9 @@ def _save_model_to_fileobj(model, fileobj, weights_format):
160
160
f .write (config_json .encode ())
161
161
162
162
weights_file_path = None
163
+ weights_store = None
164
+ asset_store = None
165
+ write_zf = False
163
166
try :
164
167
if weights_format == "h5" :
165
168
if isinstance (fileobj , io .BufferedWriter ):
@@ -168,6 +171,7 @@ def _save_model_to_fileobj(model, fileobj, weights_format):
168
171
working_dir = pathlib .Path (fileobj .name ).parent
169
172
weights_file_path = working_dir / _VARS_FNAME_H5
170
173
weights_store = H5IOStore (weights_file_path , mode = "w" )
174
+ write_zf = True
171
175
else :
172
176
# Fall back when `fileobj` is an `io.BytesIO`. Typically,
173
177
# this usage is for pickling.
@@ -196,13 +200,17 @@ def _save_model_to_fileobj(model, fileobj, weights_format):
196
200
)
197
201
except :
198
202
# Skip the final `zf.write` if any exception is raised
199
- weights_file_path = None
203
+ write_zf = False
200
204
raise
201
205
finally :
202
- weights_store .close ()
203
- asset_store .close ()
204
- if weights_file_path :
206
+ if weights_store :
207
+ weights_store .close ()
208
+ if asset_store :
209
+ asset_store .close ()
210
+ if write_zf and weights_file_path :
205
211
zf .write (weights_file_path , weights_file_path .name )
212
+ if weights_file_path :
213
+ weights_file_path .unlink ()
206
214
207
215
208
216
def load_model (filepath , custom_objects = None , compile = True , safe_mode = True ):
@@ -309,15 +317,22 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
309
317
310
318
all_filenames = zf .namelist ()
311
319
weights_file_path = None
320
+ weights_store = None
321
+ asset_store = None
312
322
try :
313
323
if _VARS_FNAME_H5 in all_filenames :
314
324
if isinstance (fileobj , io .BufferedReader ):
315
325
# First, extract the model.weights.h5 file, then load it
316
326
# using h5py.
317
327
working_dir = pathlib .Path (fileobj .name ).parent
318
- zf .extract (_VARS_FNAME_H5 , working_dir )
319
- weights_file_path = working_dir / _VARS_FNAME_H5
320
- weights_store = H5IOStore (weights_file_path , mode = "r" )
328
+ try :
329
+ zf .extract (_VARS_FNAME_H5 , working_dir )
330
+ weights_file_path = working_dir / _VARS_FNAME_H5
331
+ weights_store = H5IOStore (weights_file_path , mode = "r" )
332
+ except OSError :
333
+ # Fall back when it is a read-only system
334
+ weights_file_path = None
335
+ weights_store = H5IOStore (_VARS_FNAME_H5 , zf , mode = "r" )
321
336
else :
322
337
# Fall back when `fileobj` is an `io.BytesIO`. Typically,
323
338
# this usage is for pickling.
@@ -331,8 +346,6 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
331
346
332
347
if len (all_filenames ) > 3 :
333
348
asset_store = DiskIOStore (_ASSETS_DIRNAME , archive = zf , mode = "r" )
334
- else :
335
- asset_store = None
336
349
337
350
failed_saveables = set ()
338
351
error_msgs = {}
@@ -346,7 +359,8 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
346
359
error_msgs = error_msgs ,
347
360
)
348
361
finally :
349
- weights_store .close ()
362
+ if weights_store :
363
+ weights_store .close ()
350
364
if asset_store :
351
365
asset_store .close ()
352
366
if weights_file_path :
0 commit comments