Commit b6de193

juanuribe28 authored and Tensorflow Cloud maintainers committed
Use external entry_point for run_experiment_cloud instead of loading script.
PiperOrigin-RevId: 389246791
1 parent cb4108f commit b6de193

7 files changed: +286 -124 lines changed
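In outline, based on the diffs below: `run_experiment_cloud` no longer calls `train_lib.run_experiment` in-process on the remote worker. Instead it pickles `run_experiment_kwargs` to a uuid-prefixed params file, copies the packaged `models_entry_point.py` script to `<uuid>.py`, and submits that copy as the `entry_point` for `tfc.run`; the copied script then rebuilds the distribution strategy from a string and runs the experiment itself. A minimal caller-side sketch (the task/config objects and the GCS path are placeholders, not part of this commit):

from tensorflow_cloud.core import machine_config
from tensorflow_cloud.core.experimental import models

task = ...    # an official.core.base_task.Task for your model (placeholder)
params = ...  # an official.core ExperimentConfig (placeholder)

models.run_experiment_cloud(
    run_experiment_kwargs=dict(
        task=task,
        mode='train_and_eval',
        params=params,
        model_dir='gs://your-bucket/model_dir',  # placeholder bucket
        # distribution_strategy is ignored; it is derived from run_kwargs below
    ),
    run_kwargs=dict(
        chief_config=machine_config.COMMON_MACHINE_CONFIGS['T4_1X'],
        worker_count=0,  # -> 'one_device' strategy inside the job
    ),
)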

src/python/dependencies.py

+3
@@ -27,6 +27,8 @@ def make_required_install_packages():
         "tensorflow>=1.15.0,<3.0",
         "tensorflow_datasets",
         "tensorflow_transform",
+        "tf-models-official",
+        "importlib_resources ; python_version<'3.7'"
     ]
 
 
@@ -38,4 +40,5 @@ def make_required_test_packages():
         "numpy",
         "nbconvert",
         "tf-models-official",
+        "importlib_resources ; python_version<'3.7'"
     ]

src/python/tensorflow_cloud/core/containerize.py

+1-1
@@ -285,7 +285,7 @@ def _get_file_path_map(self):
             self.entry_point = sys.argv[0]
 
         # Map entry_point directory to the dst directory.
-        if not self.called_from_notebook:
+        if not self.called_from_notebook or self.entry_point is not None:
            entry_point_dir, _ = os.path.split(self.entry_point)
            if not entry_point_dir:  # Current directory
                entry_point_dir = "."
src/python/tensorflow_cloud/core/experimental/constants.py

+17

@@ -0,0 +1,17 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module that contains some constants used by the experimental module."""
+
+PARAMS_FILE_NAME_FORMAT = '{}_params'
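For context, a small sketch of how this constant ties the pickled params to the copied entry point (the uuid value is whatever `run_experiment_cloud` generates; shown here only for illustration):

import uuid
from tensorflow_cloud.core.experimental import constants

file_id = str(uuid.uuid4())
params_file = constants.PARAMS_FILE_NAME_FORMAT.format(file_id)  # '<uuid>_params', written by save_params in models.py
entry_point = file_id + '.py'                                     # the copy of models_entry_point.py submitted to tfc.run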

src/python/tensorflow_cloud/core/experimental/models.py

+50-45
@@ -15,17 +15,28 @@
 """Module that contains the `run_models` wrapper for training models from TF Model Garden."""
 
 import os
+import pickle
+import shutil
 from typing import Any, Dict, Optional
+import uuid
 
+from . import constants
 from .. import machine_config
 from .. import run
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
-from official.core import train_lib
 from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import resnet_model
 
+# pylint: disable=g-import-not-at-top
+try:
+    import importlib.resources as pkg_resources
+except ImportError:
+    # Backported for python<3.7
+    import importlib_resources as pkg_resources
+# pylint: enable=g-import-not-at-top
+
 
 def run_models(dataset_name: str,
                model_name: str,
@@ -239,7 +250,7 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
         run_experiment_kwargs: keyword arguments for `train_lib.run_experiment`.
             The docs can be found at
             https://github.com/tensorflow/models/blob/master/official/core/train_lib.py
-            The distribution_strategy param is ignored because the distirbution
+            The distribution_strategy param is ignored because the distribution
            strategy is selected based on run_kwargs.
        run_kwargs: keyword arguments for `tfc.run`. The docs can be found at
            https://github.com/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/core/run.py
@@ -251,48 +262,42 @@
     """
     if run_kwargs is None:
         run_kwargs = dict()
-
-    if run.remote():
-        default_machine_config = machine_config.COMMON_MACHINE_CONFIGS['T4_1X']
-        if 'chief_config' in run_kwargs:
-            chief_config = run_kwargs['chief_config']
-        else:
-            chief_config = default_machine_config
-        if 'worker_count' in run_kwargs:
-            worker_count = run_kwargs['worker_count']
+    distribution_strategy = get_distribution_strategy_str(run_kwargs)
+    run_experiment_kwargs.update(
+        dict(distribution_strategy=distribution_strategy))
+    file_id = str(uuid.uuid4())
+    params_file = save_params(run_experiment_kwargs, file_id)
+
+    with pkg_resources.path(__package__, 'models_entry_point.py') as path:
+        entry_point = f'{file_id}.py'
+        shutil.copyfile(str(path), entry_point)
+    run_kwargs.update(dict(entry_point=entry_point,
+                           distribution_strategy=None))
+    info = run.run(**run_kwargs)
+    os.remove(entry_point)
+    os.remove(params_file)
+    return info
+
+
+def get_distribution_strategy_str(run_kwargs):
+    """Gets the name of a distribution strategy based on cloud run config."""
+    if ('worker_count' in run_kwargs
+        and run_kwargs['worker_count'] > 0):
+        if ('worker_config' in run_kwargs
+            and machine_config.is_tpu_config(run_kwargs['worker_config'])):
+            return 'tpu'
         else:
-            worker_count = 0
-        if 'worker_config' in run_kwargs:
-            worker_config = run_kwargs['worker_config']
-        else:
-            worker_config = default_machine_config
-        distribution_strategy = get_distribution_strategy(chief_config,
-                                                          worker_count,
-                                                          worker_config)
-        run_experiment_kwargs.update(
-            dict(distribution_strategy=distribution_strategy))
-        model, _ = train_lib.run_experiment(**run_experiment_kwargs)
-        model.save(run_experiment_kwargs['model_dir'])
-
-    run_kwargs.update(dict(entry_point=None,
-                           distribution_strategy=None))
-    return run.run(**run_kwargs)
-
-
-def get_distribution_strategy(chief_config, worker_count, worker_config):
-    """Gets a tf distribution strategy based on the cloud run config."""
-    if worker_count > 0:
-        if machine_config.is_tpu_config(worker_config):
-            # TODO(b/194857231) Dependency conflict for using TPUs
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-                tpu='local')
-            tf.config.experimental_connect_to_cluster(resolver)
-            tf.tpu.experimental.initialize_tpu_system(resolver)
-            return tf.distribute.TPUStrategy(resolver)
-        else:
-            # TODO(b/148619319) Saving model currently failing
-            return tf.distribute.MultiWorkerMirroredStrategy()
-    elif chief_config.accelerator_count > 1:
-        return tf.distribute.MirroredStrategy()
+            return 'multi_mirror'
+    elif ('chief_config' in run_kwargs
+          and run_kwargs['chief_config'].accelerator_count > 1):
+        return 'mirror'
     else:
-        return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+        return 'one_device'
+
+
+def save_params(params, file_id):
+    """Pickles the params object using the file_id as prefix."""
+    file_name = constants.PARAMS_FILE_NAME_FORMAT.format(file_id)
+    with open(file_name, 'xb') as f:
+        pickle.dump(params, f)
+    return file_name
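A quick illustration of the mapping `get_distribution_strategy_str` implements (a sketch; it assumes 'T4_4X' and 'TPU' entries exist in COMMON_MACHINE_CONFIGS alongside the 'T4_1X' one referenced in the removed code):

from tensorflow_cloud.core import machine_config
from tensorflow_cloud.core.experimental import models

gpu = machine_config.COMMON_MACHINE_CONFIGS['T4_1X']        # single accelerator
multi_gpu = machine_config.COMMON_MACHINE_CONFIGS['T4_4X']  # accelerator_count > 1 (assumed key)
tpu = machine_config.COMMON_MACHINE_CONFIGS['TPU']          # TPU worker config (assumed key)

models.get_distribution_strategy_str({})                                       # -> 'one_device'
models.get_distribution_strategy_str(dict(chief_config=multi_gpu))             # -> 'mirror'
models.get_distribution_strategy_str(dict(worker_count=2, worker_config=gpu))  # -> 'multi_mirror'
models.get_distribution_strategy_str(dict(worker_count=1, worker_config=tpu))  # -> 'tpu'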
src/python/tensorflow_cloud/core/experimental/models_entry_point.py

+65

@@ -0,0 +1,65 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Entry point file for run_experiment_cloud."""
+
+import os
+import pickle
+
+import tensorflow as tf
+
+from tensorflow_cloud.core.experimental import constants
+from official.core import train_lib
+
+
+def load_params(file_name):
+    with open(file_name, 'rb') as f:
+        params = pickle.load(f)
+    return params
+
+
+def get_tpu_strategy():
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+        tpu='local')
+    tf.config.experimental_connect_to_cluster(resolver)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    return tf.distribute.TPUStrategy(resolver)
+
+
+def get_one_device():
+    return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+
+_DISTRIBUTION_STRATEGIES = dict(
+    # TODO(b/194857231) Dependency conflict for using TPUs
+    tpu=get_tpu_strategy,
+    # TODO(b/148619319) Saving model currently failing for multi_mirror
+    multi_mirror=tf.distribute.MultiWorkerMirroredStrategy,
+    mirror=tf.distribute.MirroredStrategy,
+    one_device=get_one_device)
+
+
+def main():
+    prefix, _ = os.path.splitext(os.path.basename(__file__))
+    file_name = constants.PARAMS_FILE_NAME_FORMAT.format(prefix)
+    run_experiment_kwargs = load_params(file_name)
+    strategy_str = run_experiment_kwargs['distribution_strategy']
+    strategy = _DISTRIBUTION_STRATEGIES[strategy_str]()
+    run_experiment_kwargs.update(dict(
+        distribution_strategy=strategy))
+    model, _ = train_lib.run_experiment(**run_experiment_kwargs)
+    model.save(run_experiment_kwargs['model_dir'])
+
+
+if __name__ == '__main__':
+    main()
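The strategy object itself is only constructed here, inside the submitted job, because the caller ships just a string: the TPU cluster resolver with tpu='local', for example, has to run on the TPU worker rather than on the machine that calls run_experiment_cloud. An illustrative dispatch (accessing the module-private table purely for demonstration):

from tensorflow_cloud.core.experimental import models_entry_point

strategy = models_entry_point._DISTRIBUTION_STRATEGIES['mirror']()  # builds the strategy inside the job
print(type(strategy).__name__)  # MirroredStrategy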
@@ -0,0 +1,77 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the models experimental module."""
+
+from absl.testing import absltest
+import mock
+import tensorflow as tf
+
+from tensorflow_cloud.core.experimental import constants
+from tensorflow_cloud.core.experimental import models_entry_point
+from official.core import base_task
+from official.core import config_definitions
+from official.core import train_lib
+
+
+class ModelsTest(absltest.TestCase):
+
+    def setUp(self):
+        super(ModelsTest, self).setUp()
+        config = mock.MagicMock(spec=config_definitions.ExperimentConfig)
+        task = mock.MagicMock(spec=base_task.Task)
+        self.run_experiment_kwargs = dict(task=task,
+                                          mode='train_and_eval',
+                                          params=config,
+                                          model_dir='model_path',
+                                          distribution_strategy='one_device')
+        self.load_params = mock.patch.object(
+            models_entry_point,
+            'load_params',
+            autospec=True,
+            return_value=self.run_experiment_kwargs,
+        ).start()
+
+        self.strategy = mock.patch.object(
+            tf.distribute,
+            'OneDeviceStrategy',
+            autospec=True,
+            return_value='one_device_strategy',
+        ).start()
+
+        self.model = mock.MagicMock()
+        self.run_experiment = mock.patch.object(
+            train_lib,
+            'run_experiment',
+            autospec=True,
+            return_value=(self.model, {})
+        ).start()
+
+    def tearDown(self):
+        mock.patch.stopall()
+        super(ModelsTest, self).tearDown()
+
+    def test_main(self):
+        models_entry_point.main()
+        file_name = constants.PARAMS_FILE_NAME_FORMAT.format(
+            'models_entry_point')
+        self.load_params.assert_called_with(file_name)
+        self.run_experiment_kwargs.update(dict(
+            distribution_strategy='one_device_strategy'))
+        self.run_experiment.assert_called_with(**self.run_experiment_kwargs)
+        self.model.save.assert_called_with(
+            self.run_experiment_kwargs['model_dir'])
+
+if __name__ == '__main__':
+    absltest.main()
