Skip to content

Commit 21efb25

Browse files
committed
fix sbw calculation with uneven all_to_all
1 parent dee58b1 commit 21efb25

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

et_replay/comm/profiler_trace_analysis.py

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,18 @@ def _get_event_busbw_factor(evt):
139139

140140
return correction_factor_func(group_size)
141141

142-
def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
142+
def _is_uneven_all_to_all_evt(evt):
143+
coll_name = _get_dict_value(
144+
evt["args"],
145+
"Collective name",
146+
f'Missing "Collective name" in event: {evt}'
147+
)
148+
return (coll_name in ["all_to_all", "all_to_allv"]
149+
and (ast.literal_eval(evt['args']['In split size'])
150+
or ast.literal_eval(evt['args']['Out split size']))
151+
)
152+
153+
def _get_uneven_all_to_all_data_size(evt, global_rank):
143154
group_size = evt["args"]["Group size"]
144155
local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(global_rank)
145156
in_elems_count = evt["args"]["In msg nelems"]
@@ -158,7 +169,10 @@ def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
158169
else:
159170
recv_elems = out_elems_count / group_size * (group_size - 1)
160171

161-
return round(max(send_elems, recv_elems) * dtype_size / evt["dur"] * 1e-3, 2)
172+
return max(send_elems, recv_elems) * dtype_size
173+
174+
def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
175+
return round(_get_uneven_all_to_all_data_size(evt, global_rank) / evt["dur"] * 1e-3, 2)
162176

163177
def calculate_bw_(trace_data, global_rank):
164178
nccl_events = [
@@ -184,10 +198,7 @@ def calculate_bw_(trace_data, global_rank):
184198

185199
algbw = _calculate_algbw(evt)
186200
busbw_factor = _get_event_busbw_factor(evt)
187-
if (coll_name in ["all_to_all", "all_to_allv"]
188-
and (ast.literal_eval(evt['args']['In split size'])
189-
or ast.literal_eval(evt['args']['Out split size']))
190-
):
201+
if _is_uneven_all_to_all_evt(evt):
191202
# calculate busbw for uneven all_to_all
192203
busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
193204
else:
@@ -206,7 +217,7 @@ def calculate_bw_(trace_data, global_rank):
206217
logger.error(f"- Error: {err_msg}")
207218

208219

209-
def calculate_sbw(trace_data):
220+
def calculate_sbw(trace_data, global_rank):
210221
# calculate shared bw per rank
211222
nccl_events = [
212223
i
@@ -221,6 +232,8 @@ def calculate_sbw(trace_data):
221232

222233
total_data_size = sum(
223234
_calculate_event_data_size(evt) * _get_event_busbw_factor(evt)
235+
if not _is_uneven_all_to_all_evt(evt)
236+
else _get_uneven_all_to_all_data_size(evt, global_rank)
224237
for evt in nccl_events
225238
)
226239

@@ -336,7 +349,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
336349
) as f:
337350
json.dump(trace, f)
338351

339-
sbw_lst.append(calculate_sbw(trace))
352+
sbw_lst.append(calculate_sbw(trace, global_rank))
340353

341354
pick_iter_e2e_time_(trace, iter_e2e_time)
342355
pick_comm_bw_(trace, comm_bw_data)
@@ -367,7 +380,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
367380
f"avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n"
368381
)
369382
f.write(
370-
f"avg. SharedBW (i.e. sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
383+
f"avg. SharedBW (i.e. sum(busbw_data_size) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
371384
)
372385

373386
f.write(

0 commit comments

Comments
 (0)