Skip to content

Commit 68dd03e

Browse files
committed
TL/UCP: move onesided a2a sync to barrier
1 parent d83caaa commit 68dd03e

File tree

5 files changed

+287
-79
lines changed

5 files changed

+287
-79
lines changed

src/components/tl/ucp/alltoall/alltoall.c

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -70,50 +70,3 @@ ucc_status_t ucc_tl_ucp_alltoall_pairwise_init(ucc_base_coll_args_t *coll_args,
7070
out:
7171
return status;
7272
}
73-
74-
ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
75-
ucc_base_team_t *team,
76-
ucc_coll_task_t **task_h)
77-
{
78-
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
79-
ucc_tl_ucp_task_t *task;
80-
ucc_status_t status;
81-
82-
ALLTOALL_TASK_CHECK(coll_args->args, tl_team);
83-
84-
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_GLOBAL_WORK_BUFFER)) {
85-
tl_error(UCC_TL_TEAM_LIB(tl_team),
86-
"global work buffer not provided nor associated with team");
87-
status = UCC_ERR_NOT_SUPPORTED;
88-
goto out;
89-
}
90-
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) {
91-
if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) {
92-
tl_error(UCC_TL_TEAM_LIB(tl_team),
93-
"non memory mapped buffers are not supported");
94-
status = UCC_ERR_NOT_SUPPORTED;
95-
goto out;
96-
}
97-
}
98-
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
99-
coll_args->args.src_memh.global_memh = NULL;
100-
}
101-
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) {
102-
coll_args->args.dst_memh.global_memh = NULL;
103-
} else {
104-
if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_DST_MEMH_GLOBAL)) {
105-
tl_error(UCC_TL_TEAM_LIB(tl_team),
106-
"onesided alltoall requires global memory handles for dst buffers");
107-
status = UCC_ERR_INVALID_PARAM;
108-
goto out;
109-
}
110-
}
111-
112-
task = ucc_tl_ucp_init_task(coll_args, team);
113-
*task_h = &task->super;
114-
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
115-
task->super.progress = ucc_tl_ucp_alltoall_onesided_progress;
116-
status = UCC_OK;
117-
out:
118-
return status;
119-
}

src/components/tl/ucp/alltoall/alltoall_onesided.c

Lines changed: 268 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,55 +11,291 @@
1111
#include "utils/ucc_math.h"
1212
#include "tl_ucp_sendrecv.h"
1313

14+
#define CONGESTION_THRESHOLD 8
15+
#define GET_OP 0
16+
#define PUT_OP 1
17+
1418
void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask);
1519

16-
ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
20+
ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_start(ucc_coll_task_t *ctask)
1721
{
18-
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
19-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
20-
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
21-
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
22-
size_t nelems = TASK_ARGS(task).src.info.count;
23-
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
24-
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
25-
ucc_rank_t start = (grank + 1) % gsize;
26-
long *pSync = TASK_ARGS(task).global_work_buffer;
27-
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
28-
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
29-
ucc_rank_t peer;
22+
return ucc_schedule_start(ctask);
23+
}
24+
25+
ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask)
26+
{
27+
ucc_schedule_t *schedule = ucc_derived_of(ctask, ucc_schedule_t);
28+
ucc_status_t status;
29+
30+
status = ucc_schedule_finalize(ctask);
31+
ucc_tl_ucp_put_schedule(schedule);
32+
return status;
33+
}
34+
35+
ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize(ucc_coll_task_t *coll_task)
36+
{
37+
ucc_status_t status;
38+
39+
status = ucc_tl_ucp_coll_finalize(coll_task);
40+
if (ucc_unlikely(UCC_OK != status)) {
41+
tl_error(UCC_TASK_LIB(coll_task), "failed to finalize collective");
42+
}
43+
return status;
44+
}
3045

31-
if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) {
32-
src_memh = TASK_ARGS(task).src_memh.global_memh[grank];
46+
void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask)
47+
{
48+
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
49+
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
50+
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
51+
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
52+
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
53+
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
54+
ucc_rank_t start = (grank + 1) % gsize;
55+
ucc_rank_t peer = (start + task->alltoall_onesided.iteration) % gsize;
56+
int iteration = task->alltoall_onesided.iteration;
57+
size_t nreqs = task->alltoall_onesided.tokens;
58+
int64_t polls = 0;
59+
int64_t npolls = task->alltoall_onesided.npolls;
60+
int use_op = task->alltoall_onesided.op;
61+
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
62+
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
63+
uint32_t *posted;
64+
uint32_t *completed;
65+
size_t nelems;
66+
67+
nelems = TASK_ARGS(task).src.info.count;
68+
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
69+
70+
if (use_op == GET_OP) {
71+
/* these are not typos */
72+
src_memh = TASK_ARGS(task).dst_memh.local_memh;
73+
dst_memh = TASK_ARGS(task).src_memh.global_memh;
74+
posted = &task->onesided.get_posted;
75+
completed = &task->onesided.get_completed;
76+
} else {
77+
posted = &task->onesided.put_posted;
78+
completed = &task->onesided.put_completed;
3379
}
80+
for (; *posted < gsize; peer = (peer + 1) % gsize) {
81+
if (use_op == PUT_OP) {
82+
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
83+
(void *)PTR_OFFSET(dest, grank * nelems), nelems, peer, src_memh, dst_memh, team, task),
84+
task, out);
85+
} else {
86+
UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems),
87+
(void *)PTR_OFFSET(src, grank * nelems), nelems, peer, src_memh, dst_memh, team, task),
88+
task, out);
89+
}
90+
++iteration;
91+
92+
if ((*posted - *completed) >= nreqs) {
93+
while (polls++ < npolls) {
94+
ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
95+
if ((*posted - *completed) < nreqs) {
96+
break;
97+
}
98+
}
99+
if (polls >= npolls) {
100+
task->alltoall_onesided.iteration = iteration;
101+
return;
102+
}
103+
}
104+
}
105+
106+
if (!UCC_TL_UCP_TASK_ONESIDED_P2P_COMPLETE(task)) {
107+
while (polls++ < npolls) {
108+
ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
109+
if (UCC_TL_UCP_TASK_ONESIDED_P2P_COMPLETE(task)) {
110+
goto complete;
111+
}
112+
}
113+
ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
114+
if (UCC_TL_UCP_TASK_ONESIDED_P2P_COMPLETE(task)) {
115+
goto complete;
116+
}
117+
return;
118+
}
119+
120+
complete:
121+
task->super.status = UCC_OK;
122+
out:
123+
return;
124+
}
125+
126+
ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
127+
{
128+
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
129+
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
130+
ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer;
131+
ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer;
132+
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
133+
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
134+
ucc_rank_t start = (grank + 1) % gsize;
135+
size_t nelems = TASK_ARGS(task).src.info.count;
136+
size_t nreqs = task->alltoall_onesided.tokens;
137+
size_t npolls = task->alltoall_onesided.npolls;
138+
int iteration = 0;
139+
int64_t polls = 0;
140+
int use_op = task->alltoall_onesided.op;
141+
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
142+
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
143+
uint32_t *posted;
144+
uint32_t *completed;
145+
ucc_rank_t peer;
34146

35147
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
36-
/* TODO: change when support for library-based work buffers is complete */
148+
if (use_op == GET_OP) {
149+
/* these are not typos */
150+
src_memh = TASK_ARGS(task).dst_memh.local_memh;
151+
dst_memh = TASK_ARGS(task).src_memh.global_memh;
152+
posted = &task->onesided.get_posted;
153+
completed = &task->onesided.get_completed;
154+
} else {
155+
posted = &task->onesided.put_posted;
156+
completed = &task->onesided.put_completed;
157+
}
37158
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
38-
dest = dest + grank * nelems;
39-
for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) {
40-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(
41-
(void *)(src + peer * nelems), (void *)dest, nelems,
42-
peer, src_memh, dst_memh, team, task),
159+
for (peer = start; *posted < gsize; peer = (peer + 1) % gsize) {
160+
if (use_op == PUT_OP) {
161+
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
162+
(void *)PTR_OFFSET(dest, grank * nelems), nelems, peer, src_memh, dst_memh, team, task),
163+
task, out);
164+
} else {
165+
UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems),
166+
(void *)PTR_OFFSET(src, grank * nelems), nelems, peer, src_memh, dst_memh, team, task),
43167
task, out);
44-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team),
45-
task, out);
168+
}
169+
++iteration;
170+
171+
if ((*posted - *completed) >= nreqs) {
172+
while (polls++ < npolls) {
173+
ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker);
174+
if ((*posted - *completed) < nreqs) {
175+
break;
176+
}
177+
}
178+
if (polls >= npolls) {
179+
break;
180+
}
181+
}
46182
}
183+
task->alltoall_onesided.iteration = iteration;
47184
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
48185
out:
49186
return task->super.status;
50187
}
51188

52-
void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask)
189+
ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
190+
ucc_base_team_t *team,
191+
ucc_coll_task_t **task_h)
53192
{
54-
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
55-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
56-
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
57-
long * pSync = TASK_ARGS(task).global_work_buffer;
193+
ucc_schedule_t *schedule = NULL;
194+
ucc_tl_ucp_team_t *tl_team =
195+
ucc_derived_of(team, ucc_tl_ucp_team_t);
196+
ucc_base_coll_args_t barrier_coll_args = {
197+
.team = team->params.team,
198+
.args.coll_type = UCC_COLL_TYPE_BARRIER,
199+
};
200+
size_t perc_bw =
201+
UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_percent_bw;
202+
ucc_on_off_auto_value_t is_get =
203+
UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_enable_get;
204+
ucc_coll_task_t *barrier_task;
205+
ucc_coll_task_t *a2a_task;
206+
ucc_tl_ucp_task_t *task;
207+
ucc_status_t status;
208+
ucc_tl_ucp_schedule_t *tl_schedule;
209+
size_t nelems;
210+
size_t rate;
211+
size_t ratio;
212+
ucp_ep_h ep;
213+
ucp_ep_evaluate_perf_param_t param;
214+
ucp_ep_evaluate_perf_attr_t attr;
215+
size_t count;
216+
int64_t npolls;
217+
ucc_sbgp_t *sbgp;
58218

59-
if (ucc_tl_ucp_test_onesided(task, gsize) == UCC_INPROGRESS) {
60-
return;
219+
ALLTOALL_TASK_CHECK(coll_args->args, tl_team);
220+
if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) {
221+
if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) {
222+
tl_error(UCC_TL_TEAM_LIB(tl_team),
223+
"non memory mapped buffers are not supported");
224+
status = UCC_ERR_NOT_SUPPORTED;
225+
goto out;
226+
}
227+
}
228+
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
229+
coll_args->args.src_memh.global_memh = NULL;
61230
}
231+
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) {
232+
coll_args->args.dst_memh.global_memh = NULL;
233+
}
234+
status = ucc_tl_ucp_get_schedule(tl_team, coll_args, (ucc_tl_ucp_schedule_t **)&tl_schedule);
235+
if (ucc_unlikely(UCC_OK != status)) {
236+
return status;
237+
}
238+
schedule = &tl_schedule->super.super;
239+
ucc_schedule_init(schedule, coll_args, team);
240+
schedule->super.post = ucc_tl_ucp_alltoall_onesided_sched_start;
241+
schedule->super.progress = NULL;
242+
schedule->super.finalize = ucc_tl_ucp_alltoall_onesided_sched_finalize;
62243

63-
pSync[0] = 0;
64-
task->super.status = UCC_OK;
244+
sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_NODE);
245+
task = ucc_tl_ucp_init_task(coll_args, team);
246+
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
247+
task->super.progress = ucc_tl_ucp_alltoall_onesided_progress;
248+
task->super.finalize = ucc_tl_ucp_alltoall_onesided_finalize;
249+
a2a_task = &task->super;
250+
251+
status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task);
252+
if (status != UCC_OK) {
253+
return status;
254+
}
255+
if (perc_bw > 100) {
256+
perc_bw = 100;
257+
} else if (perc_bw == 0) {
258+
perc_bw = 1;
259+
}
260+
261+
nelems = TASK_ARGS(task).src.info.count;
262+
nelems = nelems / UCC_TL_TEAM_SIZE(tl_team);
263+
param.field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE;
264+
attr.field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME;
265+
param.message_size = (1 << 20);
266+
ucc_tl_ucp_get_ep(tl_team,
267+
(UCC_TL_TEAM_RANK(tl_team) + 1) % UCC_TL_TEAM_SIZE(tl_team), &ep);
268+
ucp_ep_evaluate_perf(ep, &param, &attr);
269+
270+
rate = (param.message_size / attr.estimated_time) / (param.message_size);
271+
rate = rate * (double)(perc_bw / 100.0);
272+
ratio = nelems * sbgp->group_size;
273+
task->alltoall_onesided.tokens = rate / ratio;
274+
if (task->alltoall_onesided.tokens < 1) {
275+
task->alltoall_onesided.tokens = 1;
276+
}
277+
npolls = task->n_polls;
278+
if (is_get == UCC_CONFIG_ON || (is_get == UCC_CONFIG_AUTO && sbgp->group_size >= CONGESTION_THRESHOLD)) {
279+
int64_t polls;
280+
count = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
281+
polls = count - task->alltoall_onesided.tokens;
282+
if (polls > task->n_polls) {
283+
npolls = polls;
284+
}
285+
task->alltoall_onesided.op = GET_OP;
286+
} else {
287+
task->alltoall_onesided.op = PUT_OP;
288+
}
289+
task->alltoall_onesided.npolls = npolls;
290+
task->alltoall_onesided.iteration = 0;
291+
292+
ucc_schedule_add_task(schedule, a2a_task);
293+
ucc_task_subscribe_dep(&schedule->super, a2a_task, UCC_EVENT_SCHEDULE_STARTED);
294+
295+
ucc_schedule_add_task(schedule, barrier_task);
296+
ucc_task_subscribe_dep(a2a_task, barrier_task,
297+
UCC_EVENT_COMPLETED);
298+
*task_h = &schedule->super;
299+
out:
300+
return status;
65301
}

src/components/tl/ucp/tl_ucp.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,17 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = {
9595
ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoallv_hybrid_chunk_byte_limit),
9696
UCC_CONFIG_TYPE_MEMUNITS},
9797

98+
{"ALLTOALL_ONESIDED_PERCENT_BW", "100",
99+
"Percentage (1-100) of NIC bandwidth to use for congestion avoidance "
100+
"(default: 100)",
101+
ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoall_onesided_percent_bw),
102+
UCC_CONFIG_TYPE_UINT},
103+
104+
{"ALLTOALL_ONESIDED_ENABLE_GET", "auto",
105+
"Enable use of GET-based algorithm for onesided alltoall (default: auto)",
106+
ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoall_onesided_enable_get),
107+
UCC_CONFIG_TYPE_ON_OFF_AUTO},
108+
98109
{"KN_RADIX", "0",
99110
"Radix of all algorithms based on knomial pattern. When set to a "
100111
"positive value it is used as a convenience parameter to set all "

0 commit comments

Comments
 (0)