Skip to content

Commit 113c54c

Browse files
committed
fix statistics to process groups across ranks
1 parent ac529da commit 113c54c

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

et_replay/comm/profiler_trace_analysis.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,13 @@ def pick_iter_e2e_time_(trace_data, tl):
260260

261261
def pick_comm_bw_(trace_data, comm_bw_data):
262262
rank = trace_data["distributedInfo"]["rank"]
263+
264+
group_ranks_to_pg_id = defaultdict(list)
265+
for pg in trace_data["distributedInfo"]["pg_config"]:
266+
group_ranks_to_pg_id[tuple(pg["ranks"])].append(int(pg["pg_name"]))
267+
for ranks in group_ranks_to_pg_id:
268+
group_ranks_to_pg_id[ranks].sort()
269+
263270
nccl_events = [
264271
i
265272
for i in trace_data["traceEvents"]
@@ -275,10 +282,10 @@ def pick_comm_bw_(trace_data, comm_bw_data):
275282

276283
ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
277284
pg_id = int(evt["args"]["Process Group Name"])
278-
pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None
285+
# If there are multiple process groups with the same ranks, the last element
286+
# of this tuple is the identical index to differentiate them across ranks.
287+
pg = (*ranks, group_ranks_to_pg_id[tuple(ranks)].index(pg_id))
279288

280-
# TODO: calculation of unbalanced all2all bw needs to be improved
281-
# all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
282289
comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
283290
[
284291
evt["dur"],
@@ -318,11 +325,12 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
318325
if not fpath.is_file():
319326
continue
320327

321-
global_rank = int(re.search(r"rank-(\d+)", fpath.name).group(1))
322328
with open(fpath.path, "r", encoding="utf-8") as f:
323329
trace = json.load(f)
324-
330+
331+
global_rank = trace["distributedInfo"]["rank"]
325332
calculate_bw_(trace, global_rank)
333+
326334
with open(
327335
os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
328336
) as f:

0 commit comments

Comments
 (0)