Skip to content

Commit 06a1cad

Browse files
juanuribe28Tensorflow Cloud maintainers
authored and
Tensorflow Cloud maintainers
committed
Add integration tests for run_experiment_cloud wrapper.
PiperOrigin-RevId: 383893019
1 parent ef005d7 commit 06a1cad

File tree

3 files changed

+169
-1
lines changed

3 files changed

+169
-1
lines changed

.github/workflows/actions.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
pip install flake8 pytest mock
2525
- name: Install tensorflow cloud from setup
2626
run: |
27-
pip install src/python/.
27+
pip install --upgrade --use-deprecated=legacy-resolver src/python/.
2828
pip install nbconvert
2929
- name: Lint with flake8
3030
run: |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Lint as: python3
2+
# Copyright 2021 Google LLC. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Integration tests for calling run_experiment_cloud."""
16+
17+
import os
18+
import uuid
19+
20+
import tensorflow as tf
21+
import tensorflow_cloud as tfc
22+
from tensorflow_cloud.core.experimental import models
23+
from tensorflow_cloud.utils import google_api_client
24+
from official.core import task_factory
25+
from official.utils.testing import mock_task
26+
27+
# The staging bucket to use for cloudbuild as well as save the model and data.
28+
_TEST_BUCKET = os.environ["TEST_BUCKET"]
29+
_PROJECT_ID = os.environ["PROJECT_ID"]
30+
_PARENT_IMAGE = "gcr.io/deeplearning-platform-release/tf2-gpu.2-5"
31+
_BASE_PATH = f"gs://{_TEST_BUCKET}/{uuid.uuid4()}"
32+
33+
34+
class RunExperimentCloudTest(tf.test.TestCase):
35+
36+
def setUp(self):
37+
super(RunExperimentCloudTest, self).setUp()
38+
test_data_path = os.path.join(
39+
os.path.dirname(os.path.abspath(__file__)), "../testdata/"
40+
)
41+
self.requirements_txt = os.path.join(test_data_path,
42+
"requirements.txt")
43+
44+
test_config = {
45+
"trainer": {
46+
"checkpoint_interval": 10,
47+
"steps_per_loop": 10,
48+
"summary_interval": 10,
49+
"train_steps": 10,
50+
"validation_steps": 5,
51+
"validation_interval": 10,
52+
"continuous_eval_timeout": 1,
53+
"validation_summary_subdir": "validation",
54+
"optimizer_config": {
55+
"optimizer": {
56+
"type": "sgd",
57+
},
58+
"learning_rate": {
59+
"type": "constant"
60+
}
61+
}
62+
},
63+
}
64+
65+
params = mock_task.mock_experiment()
66+
params.override(test_config, is_strict=False)
67+
self.run_experiment_kwargs = dict(
68+
params=params,
69+
task=task_factory.get_task(params.task),
70+
mode="train_and_eval",
71+
)
72+
self.docker_config = tfc.DockerConfig(
73+
parent_image=_PARENT_IMAGE,
74+
image_build_bucket=_TEST_BUCKET
75+
)
76+
77+
def tpu_strategy(self):
78+
run_kwargs = dict(
79+
chief_config=tfc.COMMON_MACHINE_CONFIGS["CPU"],
80+
worker_count=1,
81+
worker_config=tfc.COMMON_MACHINE_CONFIGS["TPU"],
82+
requirements_txt=self.requirements_txt,
83+
job_labels={
84+
"job": "tpu_strategy",
85+
"team": "run_experiment_cloud_tests",
86+
},
87+
docker_config=self.docker_config,
88+
)
89+
run_experiment_kwargs = dict(
90+
model_dir=os.path.join(_BASE_PATH, "tpu", "saved_model"),
91+
**self.run_experiment_kwargs,
92+
)
93+
return models.run_experiment_cloud(run_experiment_kwargs,
94+
run_kwargs)
95+
96+
def multi_mirror_strategy(self):
97+
run_kwargs = dict(
98+
chief_config=tfc.COMMON_MACHINE_CONFIGS["P100_1X"],
99+
worker_count=1,
100+
worker_config=tfc.COMMON_MACHINE_CONFIGS["P100_1X"],
101+
requirements_txt=self.requirements_txt,
102+
job_labels={
103+
"job": "multi_mirror_strategy",
104+
"team": "run_experiment_cloud_tests",
105+
},
106+
docker_config=self.docker_config,
107+
)
108+
run_experiment_kwargs = dict(
109+
model_dir=os.path.join(_BASE_PATH, "multi_mirror", "saved_model"),
110+
**self.run_experiment_kwargs,
111+
)
112+
return models.run_experiment_cloud(run_experiment_kwargs,
113+
run_kwargs)
114+
115+
def mirror_strategy(self):
116+
run_kwargs = dict(
117+
chief_config=tfc.COMMON_MACHINE_CONFIGS["P100_4X"],
118+
requirements_txt=self.requirements_txt,
119+
job_labels={
120+
"job": "mirror",
121+
"team": "run_experiment_cloud_tests",
122+
},
123+
docker_config=self.docker_config,
124+
)
125+
run_experiment_kwargs = dict(
126+
model_dir=os.path.join(_BASE_PATH, "mirror", "saved_model"),
127+
**self.run_experiment_kwargs,
128+
)
129+
return models.run_experiment_cloud(run_experiment_kwargs,
130+
run_kwargs)
131+
132+
def one_device_strategy(self):
133+
run_kwargs = dict(
134+
requirements_txt=self.requirements_txt,
135+
job_labels={
136+
"job": "one_device",
137+
"team": "run_experiment_cloud_tests",
138+
},
139+
docker_config=self.docker_config,
140+
)
141+
run_experiment_kwargs = dict(
142+
model_dir=os.path.join(_BASE_PATH, "one_device", "saved_model"),
143+
**self.run_experiment_kwargs,
144+
)
145+
# Using the default T4 GPU for this test.
146+
return models.run_experiment_cloud(run_experiment_kwargs,
147+
run_kwargs)
148+
149+
def test_run_experiment_cloud(self):
150+
track_status = {
151+
"one_device_strategy": self.one_device_strategy(),
152+
"mirror_strategy": self.mirror_strategy(),
153+
# TODO(b/148619319) Enable when bug is solved
154+
# "multi_mirror_strategy": self.multi_mirror_strategy(),
155+
# TODO(b/194857231) Enable when bug is solved
156+
# "tpu_strategy": self.tpu_strategy(),
157+
}
158+
159+
for test_name, ret_val in track_status.items():
160+
self.assertTrue(
161+
google_api_client.wait_for_aip_training_job_completion(
162+
ret_val["job_id"], _PROJECT_ID),
163+
"Job {} generated from the test: {} has failed".format(
164+
ret_val["job_id"], test_name))
165+
166+
if __name__ == "__main__":
167+
tf.test.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tf-models-official

0 commit comments

Comments
 (0)