Commit 67171fa

feature: Add heterogeneous cluster changes (#421)
1 parent 30fed8f commit 67171fa

File tree

setup.py
src/sagemaker_tensorflow_container/training.py
test/unit/test_training.py

3 files changed: +23 -17 lines


setup.py (+1, -1)

@@ -45,7 +45,7 @@ def read_version():
     "botocore==1.19.34",
     "requests-mock",
     "awscli==1.18.194",
-    "protobuf>=3.20,<3.21"
+    "protobuf>=3.9.2,<3.20"
 ]
 
 if sys.version_info.major > 2:
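The net effect is a looser, older protobuf range: 3.9.2 up to (but not including) 3.20, replacing the previous 3.20.x pin. A minimal sketch for checking an installed protobuf against the new range (the specifier string comes from the diff; the check itself is illustrative and assumes the packaging library, which ships alongside setuptools):

# Illustrative compatibility check for the new protobuf pin.
from importlib.metadata import version  # stdlib, Python 3.8+

from packaging.specifiers import SpecifierSet

PIN = SpecifierSet(">=3.9.2,<3.20")  # range introduced by this commit

installed = version("protobuf")
if installed in PIN:
    print(f"protobuf {installed} satisfies {PIN}")
else:
    print(f"protobuf {installed} falls outside {PIN}")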

src/sagemaker_tensorflow_container/training.py (+17, -16)

@@ -176,17 +176,20 @@ def train(env, cmd_args):
     env_vars = env.to_env_vars()
 
     # Setup
-    if parameter_server_enabled:
+    if env.current_instance_group in env.distribution_instance_groups:
+        if parameter_server_enabled:
 
-        tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
-        logger.info("Running distributed training job with parameter servers")
+            tf_config = _build_tf_config_for_ps(hosts=env.distribution_hosts, current_host=env.current_host)
+            logger.info("Running distributed training job with parameter servers")
 
-    elif multi_worker_mirrored_strategy_enabled:
+        elif multi_worker_mirrored_strategy_enabled:
 
-        env_vars["TF_CONFIG"] = json.dumps(
-            _build_tf_config_for_mwms(hosts=env.hosts, current_host=env.current_host)
-        )
-        logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")
+            env_vars["TF_CONFIG"] = json.dumps(
+                _build_tf_config_for_mwms(hosts=env.distribution_hosts, current_host=env.current_host)
+            )
+            logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")
+
+    runner_type = runner.ProcessRunnerType
 
     # Run
     if parameter_server_enabled:
@@ -200,15 +203,13 @@ def train(env, cmd_args):
         _wait_until_master_is_down(env.hosts[0])
 
     else:
+        if env.current_instance_group in env.distribution_instance_groups:
+            mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")
 
-        mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")
-
-        if mpi_enabled:
-            runner_type = runner.MPIRunnerType
-        elif sagemaker_distributed_dataparallel_enabled:
-            runner_type = runner.SMDataParallelRunnerType
-        else:
-            runner_type = runner.ProcessRunnerType
+            if mpi_enabled:
+                runner_type = runner.MPIRunnerType
+            elif sagemaker_distributed_dataparallel_enabled:
+                runner_type = runner.SMDataParallelRunnerType
 
     entry_point.run(
         uri=env.module_dir,
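
Taken together, the training.py changes gate distributed setup on whether the current host's instance group is one of the groups participating in distribution, build TF_CONFIG from env.distribution_hosts rather than env.hosts, and hoist the runner.ProcessRunnerType default out of the else chain so that hosts outside the distribution groups still get a defined runner. A self-contained sketch of the resulting runner selection, with illustrative stand-ins for the env object and runner constants (attribute names mirror the diff; everything else is assumption):

from dataclasses import dataclass, field

# Illustrative stand-ins for the toolkit's runner constants.
PROCESS, MPI, SMDDP = "ProcessRunnerType", "MPIRunnerType", "SMDataParallelRunnerType"


@dataclass
class Env:
    # Attribute names mirror the diff; values are made up for the sketch.
    current_instance_group: str = "train"
    distribution_instance_groups: list = field(default_factory=lambda: ["train"])
    additional_framework_parameters: dict = field(default_factory=dict)


def select_runner(env, smddp_enabled=False):
    runner_type = PROCESS  # hoisted default, as in the new setup block
    if env.current_instance_group in env.distribution_instance_groups:
        if env.additional_framework_parameters.get("sagemaker_mpi_enabled"):
            runner_type = MPI
        elif smddp_enabled:
            runner_type = SMDDP
    return runner_type


print(select_runner(Env(current_instance_group="other")))  # ProcessRunnerType
print(select_runner(Env(additional_framework_parameters={"sagemaker_mpi_enabled": True})))  # MPIRunnerType

Hoisting the default matters: under a heterogeneous cluster a host can now skip the MPI/SMDataParallel selection entirely, and without the unconditional assignment runner_type would be unbound when control reaches entry_point.run.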

test/unit/test_training.py (+5, -0)

@@ -52,6 +52,9 @@ def distributed_training_env():
     env = simple_training_env()
 
     env.hosts = HOST_LIST
+    env.current_instance_group = "test1"
+    env.distribution_hosts = ["host1", "host2"]
+    env.distribution_instance_groups = ["test1"]
     env.additional_framework_parameters = {training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True}
     return env
 
@@ -63,6 +66,8 @@ def single_machine_training_env():
 
 def simple_training_env():
     env = MagicMock()
+    env.current_instance_group = "test1"
+    env.distribution_instance_groups = ["test1"]
     env.module_dir = MODULE_DIR
     env.user_entry_point = MODULE_NAME
     env.hyperparameters = {"model_dir": MODEL_DIR}
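
Both fixtures now place the mock host in instance group "test1" and list that group as distribution-enabled, so the existing tests keep exercising the distributed code paths. A sketch of the kind of case the new attributes make possible (the excluded-group test is hypothetical, not part of this commit):

from unittest.mock import MagicMock


def simple_training_env():
    # Mirrors the fixture from the diff: a MagicMock standing in for the
    # SageMaker training environment, with heterogeneous-cluster attributes.
    env = MagicMock()
    env.current_instance_group = "test1"
    env.distribution_instance_groups = ["test1"]
    return env


def test_instance_group_participates():
    env = simple_training_env()
    assert env.current_instance_group in env.distribution_instance_groups


def test_instance_group_excluded():
    # Hypothetical: a host whose group is absent from the distribution list
    # should skip the distributed setup entirely.
    env = simple_training_env()
    env.current_instance_group = "test2"
    assert env.current_instance_group not in env.distribution_instance_groups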
