@@ -260,6 +260,13 @@ def pick_iter_e2e_time_(trace_data, tl):
260260
261261def pick_comm_bw_ (trace_data , comm_bw_data ):
262262 rank = trace_data ["distributedInfo" ]["rank" ]
263+
264+ group_ranks_to_pg_id = defaultdict (list )
265+ for pg in trace_data ["distributedInfo" ]["pg_config" ]:
266+ group_ranks_to_pg_id [tuple (pg ["ranks" ])].append (int (pg ["pg_name" ]))
267+ for ranks in group_ranks_to_pg_id :
268+ group_ranks_to_pg_id [ranks ].sort ()
269+
263270 nccl_events = [
264271 i
265272 for i in trace_data ["traceEvents" ]
@@ -275,10 +282,10 @@ def pick_comm_bw_(trace_data, comm_bw_data):
275282
276283 ranks = _parse_ranks (evt ["args" ]["Process Group Ranks" ], ranks_count )
277284 pg_id = int (evt ["args" ]["Process Group Name" ])
278- pg = (* ranks , pg_id ) if ranks and rank == min (ranks ) else None
285+ # If there are multiple process groups with the same ranks, the last element
286+ # of this tuple is the identical index to differentiate them across ranks.
287+ pg = (* ranks , group_ranks_to_pg_id [tuple (ranks )].index (pg_id ))
279288
280- # TODO: calculation of unbalanced all2all bw needs to be improved
281- # all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
282289 comm_bw_data [(knl_name , coll_name , data_size , ranks_count )].append (
283290 [
284291 evt ["dur" ],
@@ -318,11 +325,12 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
318325 if not fpath .is_file ():
319326 continue
320327
321- global_rank = int (re .search (r"rank-(\d+)" , fpath .name ).group (1 ))
322328 with open (fpath .path , "r" , encoding = "utf-8" ) as f :
323329 trace = json .load (f )
324-
330+
331+ global_rank = trace ["distributedInfo" ]["rank" ]
325332 calculate_bw_ (trace , global_rank )
333+
326334 with open (
327335 os .path .join (processed_trace_dir , fpath .name ), "w" , encoding = "utf-8"
328336 ) as f :
0 commit comments