@@ -176,17 +176,20 @@ def train(env, cmd_args):
176
176
env_vars = env .to_env_vars ()
177
177
178
178
# Setup
179
- if parameter_server_enabled :
179
+ if env .current_instance_group in env .distribution_instance_groups :
180
+ if parameter_server_enabled :
180
181
181
- tf_config = _build_tf_config_for_ps (hosts = env .hosts , current_host = env .current_host )
182
- logger .info ("Running distributed training job with parameter servers" )
182
+ tf_config = _build_tf_config_for_ps (hosts = env .distribution_hosts , current_host = env .current_host )
183
+ logger .info ("Running distributed training job with parameter servers" )
183
184
184
- elif multi_worker_mirrored_strategy_enabled :
185
+ elif multi_worker_mirrored_strategy_enabled :
185
186
186
- env_vars ["TF_CONFIG" ] = json .dumps (
187
- _build_tf_config_for_mwms (hosts = env .hosts , current_host = env .current_host )
188
- )
189
- logger .info ("Running distributed training job with multi_worker_mirrored_strategy setup" )
187
+ env_vars ["TF_CONFIG" ] = json .dumps (
188
+ _build_tf_config_for_mwms (hosts = env .distribution_hosts , current_host = env .current_host )
189
+ )
190
+ logger .info ("Running distributed training job with multi_worker_mirrored_strategy setup" )
191
+
192
+ runner_type = runner .ProcessRunnerType
190
193
191
194
# Run
192
195
if parameter_server_enabled :
@@ -200,15 +203,13 @@ def train(env, cmd_args):
200
203
_wait_until_master_is_down (env .hosts [0 ])
201
204
202
205
else :
206
+ if env .current_instance_group in env .distribution_instance_groups :
207
+ mpi_enabled = env .additional_framework_parameters .get ("sagemaker_mpi_enabled" )
203
208
204
- mpi_enabled = env .additional_framework_parameters .get ("sagemaker_mpi_enabled" )
205
-
206
- if mpi_enabled :
207
- runner_type = runner .MPIRunnerType
208
- elif sagemaker_distributed_dataparallel_enabled :
209
- runner_type = runner .SMDataParallelRunnerType
210
- else :
211
- runner_type = runner .ProcessRunnerType
209
+ if mpi_enabled :
210
+ runner_type = runner .MPIRunnerType
211
+ elif sagemaker_distributed_dataparallel_enabled :
212
+ runner_type = runner .SMDataParallelRunnerType
212
213
213
214
entry_point .run (
214
215
uri = env .module_dir ,
0 commit comments