@@ -15,17 +15,28 @@
 """Module that contains the `run_models` wrapper for training models from TF Model Garden."""
 
 import os
+import pickle
+import shutil
 from typing import Any, Dict, Optional
+import uuid
 
+from . import constants
 from .. import machine_config
 from .. import run
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
-from official.core import train_lib
 from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import resnet_model
 
+# pylint: disable=g-import-not-at-top
+try:
+  import importlib.resources as pkg_resources
+except ImportError:
+  # Backported for python<3.7
+  import importlib_resources as pkg_resources
+# pylint: enable=g-import-not-at-top
+
 
 def run_models(dataset_name: str,
                model_name: str,
@@ -239,7 +250,7 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
     run_experiment_kwargs: keyword arguments for `train_lib.run_experiment`.
       The docs can be found at
       https://github.com/tensorflow/models/blob/master/official/core/train_lib.py
-      The distribution_strategy param is ignored because the distirbution
+      The distribution_strategy param is ignored because the distribution
       strategy is selected based on run_kwargs.
     run_kwargs: keyword arguments for `tfc.run`. The docs can be found at
       https://github.com/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/core/run.py
@@ -251,48 +262,42 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
   """
   if run_kwargs is None:
     run_kwargs = dict()
-
-  if run.remote():
-    default_machine_config = machine_config.COMMON_MACHINE_CONFIGS['T4_1X']
-    if 'chief_config' in run_kwargs:
-      chief_config = run_kwargs['chief_config']
-    else:
-      chief_config = default_machine_config
-    if 'worker_count' in run_kwargs:
-      worker_count = run_kwargs['worker_count']
+  distribution_strategy = get_distribution_strategy_str(run_kwargs)
+  run_experiment_kwargs.update(
+      dict(distribution_strategy=distribution_strategy))
+  file_id = str(uuid.uuid4())
+  params_file = save_params(run_experiment_kwargs, file_id)
+
+  with pkg_resources.path(__package__, 'models_entry_point.py') as path:
+    entry_point = f'{file_id}.py'
+    shutil.copyfile(str(path), entry_point)
+  run_kwargs.update(dict(entry_point=entry_point,
+                         distribution_strategy=None))
+  info = run.run(**run_kwargs)
+  os.remove(entry_point)
+  os.remove(params_file)
+  return info
+
+
+def get_distribution_strategy_str(run_kwargs):
+  """Gets the name of a distribution strategy based on cloud run config."""
+  if ('worker_count' in run_kwargs
+      and run_kwargs['worker_count'] > 0):
+    if ('worker_config' in run_kwargs
+        and machine_config.is_tpu_config(run_kwargs['worker_config'])):
+      return 'tpu'
     else:
-      worker_count = 0
-    if 'worker_config' in run_kwargs:
-      worker_config = run_kwargs['worker_config']
-    else:
-      worker_config = default_machine_config
-    distribution_strategy = get_distribution_strategy(chief_config,
-                                                      worker_count,
-                                                      worker_config)
-    run_experiment_kwargs.update(
-        dict(distribution_strategy=distribution_strategy))
-    model, _ = train_lib.run_experiment(**run_experiment_kwargs)
-    model.save(run_experiment_kwargs['model_dir'])
-
-  run_kwargs.update(dict(entry_point=None,
-                         distribution_strategy=None))
-  return run.run(**run_kwargs)
-
-
-def get_distribution_strategy(chief_config, worker_count, worker_config):
-  """Gets a tf distribution strategy based on the cloud run config."""
-  if worker_count > 0:
-    if machine_config.is_tpu_config(worker_config):
-      # TODO(b/194857231) Dependency conflict for using TPUs
-      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-          tpu='local')
-      tf.config.experimental_connect_to_cluster(resolver)
-      tf.tpu.experimental.initialize_tpu_system(resolver)
-      return tf.distribute.TPUStrategy(resolver)
-    else:
-      # TODO(b/148619319) Saving model currently failing
-      return tf.distribute.MultiWorkerMirroredStrategy()
-  elif chief_config.accelerator_count > 1:
-    return tf.distribute.MirroredStrategy()
+      return 'multi_mirror'
+  elif ('chief_config' in run_kwargs
+        and run_kwargs['chief_config'].accelerator_count > 1):
+    return 'mirror'
   else:
-    return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+    return 'one_device'
+
+
+def save_params(params, file_id):
+  """Pickles the params object using the file_id as prefix."""
+  file_name = constants.PARAMS_FILE_NAME_FORMAT.format(file_id)
+  with open(file_name, 'xb') as f:
+    pickle.dump(params, f)
+  return file_name
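
Note for reviewers: `models_entry_point.py` itself is not touched by this diff, so the following is only a sketch of the shape it presumably has. It reverses `save_params` and maps the strategy strings back to concrete `tf.distribute` strategies, roughly the logic that the removed `get_distribution_strategy` helper contained. `PARAMS_FILE` and `strategy_from_str` are placeholders introduced here, not names from the diff; the real script presumably derives the pickle path from its own `file_id`-based filename via `constants.PARAMS_FILE_NAME_FORMAT`.

```python
# Hypothetical sketch of models_entry_point.py (not part of this diff).
import pickle

import tensorflow as tf
from official.core import train_lib

PARAMS_FILE = 'REPLACE_ME.pkl'  # placeholder; the real script derives this path


def strategy_from_str(name):
  """Inverse of get_distribution_strategy_str: strategy name -> tf.distribute object."""
  if name == 'tpu':
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
  if name == 'multi_mirror':
    return tf.distribute.MultiWorkerMirroredStrategy()
  if name == 'mirror':
    return tf.distribute.MirroredStrategy()
  return tf.distribute.OneDeviceStrategy(device='/gpu:0')


with open(PARAMS_FILE, 'rb') as f:
  run_experiment_kwargs = pickle.load(f)

# Swap the serialized strategy name for a real strategy, then train and save,
# mirroring the model.save() call that this diff removes from the client side.
name = run_experiment_kwargs['distribution_strategy']
run_experiment_kwargs['distribution_strategy'] = strategy_from_str(name)
model, _ = train_lib.run_experiment(**run_experiment_kwargs)
model.save(run_experiment_kwargs['model_dir'])
```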
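For completeness, a hypothetical call site under the new flow. The import paths, the experiment name `'resnet_imagenet'`, and the GCS path are illustrative assumptions; `COMMON_MACHINE_CONFIGS['T4_1X']` is the same config the removed code used as its default.

```python
# Hypothetical usage sketch; experiment name and GCS path are placeholders.
from official.core import exp_factory, task_factory
from tensorflow_cloud.core import machine_config
from tensorflow_cloud.core.experimental.models import run_experiment_cloud  # assumed path

params = exp_factory.get_exp_config('resnet_imagenet')  # assumed experiment name
task = task_factory.get_task(params.task)

info = run_experiment_cloud(
    run_experiment_kwargs=dict(
        task=task,
        mode='train_and_eval',
        params=params,
        model_dir='gs://some-bucket/model_dir',
    ),
    run_kwargs=dict(
        worker_count=2,  # >0 with non-TPU workers, so the strategy str is 'multi_mirror'
        worker_config=machine_config.COMMON_MACHINE_CONFIGS['T4_1X'],
    ),
)
```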