2 files changed: +17 −4 lines changed

@@ -259,7 +259,10 @@ def maybe_enable_amp(
 
 
 def init_distributed(
-    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
+    comm_config: CommConfig,
+    enable_cpu_backend: bool = False,
+    base_folder: str = "",
+    ranks: list[int] | None = None,
 ):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
@@ -303,6 +306,7 @@ def _get_distributed_backend(enable_cpu_backend):
     torch.distributed.init_process_group(
         backend=_get_distributed_backend(enable_cpu_backend),
         timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        _ranks=ranks if ranks is not None else [],
     )
 
 
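For context, a sketch of how a caller might use the new `ranks` parameter; everything below other than `init_distributed` and `CommConfig` is assumed for illustration, not taken from this PR:

```python
# Hypothetical usage sketch: initialize the process group for only the
# global ranks that belong to one fault-tolerant replica, instead of
# every rank in the job. The config values here are made up.
comm_config = CommConfig(init_timeout_seconds=300)

replica_ranks = list(range(8, 16))  # e.g. replica 1 with group_size=8
init_distributed(
    comm_config,
    enable_cpu_backend=False,
    base_folder="outputs",
    ranks=replica_ranks,  # forwarded to init_process_group as _ranks
)
```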
@@ -453,9 +457,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
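The reformatted branch relies on the standard identity for combining p-norms of disjoint parameter groups; a self-contained sketch of that identity, independent of the surrounding code:

```python
import math

import torch


def combine_group_norms(
    ep_norm: torch.Tensor, non_ep_norm: torch.Tensor, norm_type: float
) -> torch.Tensor:
    # Because the EP and non-EP gradients are disjoint sets of tensors,
    # ||g||_p = (||g_ep||_p**p + ||g_non_ep||_p**p) ** (1/p),
    # and for p = inf the combined norm is simply the max of the two.
    if math.isinf(norm_type):
        return torch.maximum(ep_norm, non_ep_norm)
    total_norm = ep_norm**norm_type + non_ep_norm**norm_type
    return total_norm ** (1.0 / norm_type)


# e.g. two groups with L2 norms 3.0 and 4.0 combine to a global norm of 5.0
print(combine_group_norms(torch.tensor(3.0), torch.tensor(4.0), 2.0))  # tensor(5.)
```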
@@ -84,11 +84,22 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
 
+        # determine the global ranks when fault tolerance is enabled
+        global_ranks = []
+        ft_config = job_config.fault_tolerance
+        if ft_config.enable:
+            group_size = ft_config.group_size
+            replica_id = ft_config.replica_id
+            first_rank = replica_id * group_size
+            last_rank = first_rank + group_size - 1
+            global_ranks = list(range(first_rank, last_rank + 1))
+
         # init distributed and build meshes
         dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
+            ranks=global_ranks,
         )
 
         job_config.maybe_log()
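A worked example of the rank arithmetic above, with assumed config values: each replica owns a contiguous block of `group_size` global ranks, offset by `replica_id`.

```python
# With group_size=8 and replica_id=2 (hypothetical values), the replica
# spans global ranks 16..23.
group_size, replica_id = 8, 2

first_rank = replica_id * group_size     # 16
last_rank = first_rank + group_size - 1  # 23
global_ranks = list(range(first_rank, last_rank + 1))

assert global_ranks == [16, 17, 18, 19, 20, 21, 22, 23]
```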