File tree 1 file changed +2
-10
lines changed
1 file changed +2
-10
lines changed Original file line number Diff line number Diff line change @@ -449,11 +449,7 @@ def fit(
449
449
450
450
# Override with model metrics instead of last step logs if
451
451
# needed.
452
- # The jax spmd_mode is need for multi-process context, since the
453
- # metrics values are replicated, and we don't want to do a all
454
- # gather, and only need the local copy of the value.
455
- with jax .spmd_mode ("allow_all" ):
456
- epoch_logs = dict (self ._get_metrics_result_or_logs (logs ))
452
+ epoch_logs = dict (self ._get_metrics_result_or_logs (logs ))
457
453
458
454
# Run validation.
459
455
if validation_data is not None and self ._should_eval (
@@ -605,11 +601,7 @@ def evaluate(
605
601
# Reattach state back to model (if not already done by a callback).
606
602
self .jax_state_sync ()
607
603
608
- # The jax spmd_mode is need for multi-process context, since the
609
- # metrics values are replicated, and we don't want to do a all
610
- # gather, and only need the local copy of the value.
611
- with jax .spmd_mode ("allow_all" ):
612
- logs = self ._get_metrics_result_or_logs (logs )
604
+ logs = self ._get_metrics_result_or_logs (logs )
613
605
callbacks .on_test_end (logs )
614
606
self ._jax_state = None
615
607
if not use_cached_eval_dataset :
You can’t perform that action at this time.
0 commit comments