@@ -139,7 +139,18 @@ def _get_event_busbw_factor(evt):
139
139
140
140
return correction_factor_func (group_size )
141
141
142
- def _calculate_busbw_for_uneven_all_to_all (evt , global_rank ):
142
+ def _is_uneven_all_to_all_evt (evt ):
143
+ coll_name = _get_dict_value (
144
+ evt ["args" ],
145
+ "Collective name" ,
146
+ f'Missing "Collective name" in event: { evt } '
147
+ )
148
+ return (coll_name in ["all_to_all" , "all_to_allv" ]
149
+ and (ast .literal_eval (evt ['args' ]['In split size' ])
150
+ or ast .literal_eval (evt ['args' ]['Out split size' ]))
151
+ )
152
+
153
+ def _get_uneven_all_to_all_data_size (evt , global_rank ):
143
154
group_size = evt ["args" ]["Group size" ]
144
155
local_rank = _parse_ranks (evt ["args" ]["Process Group Ranks" ], group_size ).index (global_rank )
145
156
in_elems_count = evt ["args" ]["In msg nelems" ]
@@ -158,7 +169,10 @@ def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
158
169
else :
159
170
recv_elems = out_elems_count / group_size * (group_size - 1 )
160
171
161
- return round (max (send_elems , recv_elems ) * dtype_size / evt ["dur" ] * 1e-3 , 2 )
172
+ return max (send_elems , recv_elems ) * dtype_size
173
+
174
+ def _calculate_busbw_for_uneven_all_to_all (evt , global_rank ):
175
+ return round (_get_uneven_all_to_all_data_size (evt , global_rank ) / evt ["dur" ] * 1e-3 , 2 )
162
176
163
177
def calculate_bw_ (trace_data , global_rank ):
164
178
nccl_events = [
@@ -184,10 +198,7 @@ def calculate_bw_(trace_data, global_rank):
184
198
185
199
algbw = _calculate_algbw (evt )
186
200
busbw_factor = _get_event_busbw_factor (evt )
187
- if (coll_name in ["all_to_all" , "all_to_allv" ]
188
- and (ast .literal_eval (evt ['args' ]['In split size' ])
189
- or ast .literal_eval (evt ['args' ]['Out split size' ]))
190
- ):
201
+ if _is_uneven_all_to_all_evt (evt ):
191
202
# calculate busbw for uneven all_to_all
192
203
busbw = _calculate_busbw_for_uneven_all_to_all (evt , global_rank )
193
204
else :
@@ -206,7 +217,7 @@ def calculate_bw_(trace_data, global_rank):
206
217
logger .error (f"- Error: { err_msg } " )
207
218
208
219
209
- def calculate_sbw (trace_data ):
220
+ def calculate_sbw (trace_data , global_rank ):
210
221
# calculate shared bw per rank
211
222
nccl_events = [
212
223
i
@@ -221,6 +232,8 @@ def calculate_sbw(trace_data):
221
232
222
233
total_data_size = sum (
223
234
_calculate_event_data_size (evt ) * _get_event_busbw_factor (evt )
235
+ if not _is_uneven_all_to_all_evt (evt )
236
+ else _get_uneven_all_to_all_data_size (evt , global_rank )
224
237
for evt in nccl_events
225
238
)
226
239
@@ -336,7 +349,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
336
349
) as f :
337
350
json .dump (trace , f )
338
351
339
- sbw_lst .append (calculate_sbw (trace ))
352
+ sbw_lst .append (calculate_sbw (trace , global_rank ))
340
353
341
354
pick_iter_e2e_time_ (trace , iter_e2e_time )
342
355
pick_comm_bw_ (trace , comm_bw_data )
@@ -367,7 +380,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
367
380
f"avg. E2ETime of iters among all ranks: { sum (iter_e2e_time ) / len (iter_e2e_time ) / 1e3 :.3f} ms\n "
368
381
)
369
382
f .write (
370
- f"avg. SharedBW (i.e. sum(data_size * busbw_factor ) / GPU_comm_busy_time per rank) among all ranks: { sum (sbw_lst ) / len (sbw_lst ) :.3f} GB/s\n "
383
+ f"avg. SharedBW (i.e. sum(busbw_data_size ) / GPU_comm_busy_time per rank) among all ranks: { sum (sbw_lst ) / len (sbw_lst ) :.3f} GB/s\n "
371
384
)
372
385
373
386
f .write (
0 commit comments