Skip to content

Commit c90a3a5

Browse files
authored
Removing references to jax.config.spmd_mode('allow_all'). (#21164)
This flag no longer does anything in jax.
1 parent 2111fbc commit c90a3a5

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

Diff for: keras/src/backend/jax/trainer.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -449,11 +449,7 @@ def fit(
449449

450450
# Override with model metrics instead of last step logs if
451451
# needed.
452-
# The jax spmd_mode is need for multi-process context, since the
453-
# metrics values are replicated, and we don't want to do a all
454-
# gather, and only need the local copy of the value.
455-
with jax.spmd_mode("allow_all"):
456-
epoch_logs = dict(self._get_metrics_result_or_logs(logs))
452+
epoch_logs = dict(self._get_metrics_result_or_logs(logs))
457453

458454
# Run validation.
459455
if validation_data is not None and self._should_eval(
@@ -605,11 +601,7 @@ def evaluate(
605601
# Reattach state back to model (if not already done by a callback).
606602
self.jax_state_sync()
607603

608-
# The jax spmd_mode is need for multi-process context, since the
609-
# metrics values are replicated, and we don't want to do a all
610-
# gather, and only need the local copy of the value.
611-
with jax.spmd_mode("allow_all"):
612-
logs = self._get_metrics_result_or_logs(logs)
604+
logs = self._get_metrics_result_or_logs(logs)
613605
callbacks.on_test_end(logs)
614606
self._jax_state = None
615607
if not use_cached_eval_dataset:

0 commit comments

Comments
 (0)