Skip to content

Commit cb4108f

Browse files
LukeWoodTensorflow Cloud maintainers
authored and
Tensorflow Cloud maintainers
committed
core/validate.py#L176 checks the current TensorFlow Version on your local machine. It then asserts that the value is in the list from get_cloud_tpu_supported_tf_version which contains the list [“2.1”]. This makes it currently impossible to perform TPU training (from what I can tell, you can’t get a TF version where tf.__version__ is exactly “2.1” with no postpended subversion.
Removes this check so TPU training will work again PiperOrigin-RevId: 388815387
1 parent f1ae448 commit cb4108f

File tree

2 files changed

+0
-37
lines changed

2 files changed

+0
-37
lines changed

src/python/tensorflow_cloud/core/tests/unit/validate_test.py

-26
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import os
1717

1818
from absl.testing import absltest
19-
import mock
2019

2120
from tensorflow_cloud.core import machine_config
2221
from tensorflow_cloud.core import validate
@@ -302,30 +301,5 @@ def test_invalid_tpu_accelerator_count(self):
302301
called_from_notebook=False,
303302
)
304303

305-
@mock.patch("tensorflow_cloud.utils.tf_utils.get_version") # pylint: disable=line-too-long
306-
def test_invalid_tpu_accelerator_tf_version(self, mock_get_version):
307-
mock_get_version.return_value = "2.2.0"
308-
with self.assertRaisesRegex(
309-
NotImplementedError,
310-
r"TPUs are only supported for TF version <= 2.1.0",
311-
):
312-
validate.validate(
313-
entry_point=None,
314-
distribution_strategy="auto",
315-
requirements_txt=None,
316-
chief_config=machine_config.COMMON_MACHINE_CONFIGS["CPU"],
317-
worker_config=machine_config.MachineConfig(
318-
accelerator_type=machine_config.AcceleratorType.TPU_V2,
319-
accelerator_count=8,
320-
),
321-
worker_count=1,
322-
entry_point_args=None,
323-
stream_logs=True,
324-
docker_image_build_bucket=None,
325-
called_from_notebook=False,
326-
docker_parent_image=None,
327-
)
328-
329-
330304
if __name__ == "__main__":
331305
absltest.main()

src/python/tensorflow_cloud/core/validate.py

-11
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from . import gcp
1919
from . import machine_config
20-
from ..utils import tf_utils
2120

2221

2322
def validate(
@@ -169,16 +168,6 @@ def _validate_cluster_config(
169168
"Expected worker_count=1 for TPU `worker_config`. "
170169
"Received {}.".format(worker_count)
171170
)
172-
elif docker_parent_image is None:
173-
# If the user has not provided a custom Docker image, then verify
174-
# that the TF version is compatible with Cloud TPU support.
175-
# https://cloud.google.com/ai-platform/training/docs/runtime-version-list#tpu-support # pylint: disable=line-too-long
176-
version = tf_utils.get_version()
177-
if (version is not None and
178-
version not in gcp.get_cloud_tpu_supported_tf_versions()):
179-
raise NotImplementedError(
180-
"TPUs are only supported for TF version <= 2.1.0"
181-
)
182171

183172

184173
def _validate_job_labels(job_labels):

0 commit comments

Comments
 (0)