Skip to content

Commit 02147dd

Browse files
TL/UCP: make local copy non-blocking (nb) in allgather
1 parent 75ecf74 commit 02147dd

File tree

3 files changed

+62
-12
lines changed

3 files changed

+62
-12
lines changed

src/components/tl/ucp/allgather/allgather_ring.c

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
4040
size_t count = TASK_ARGS(task).dst.info.count;
4141
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
4242
size_t data_size = (count / tsize) * ucc_dt_size(dt);
43+
ucc_status_t status = UCC_OK;
4344
ucc_rank_t sendto, recvfrom, sblock, rblock;
4445
int step;
4546
void *buf;
@@ -69,7 +70,14 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
6970
}
7071
}
7172
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
72-
task->super.status = UCC_OK;
73+
if (task->allgather_ring.etask) {
74+
status = ucc_ee_executor_task_test(task->allgather_ring.etask);
75+
if (status == UCC_INPROGRESS) {
76+
return;
77+
}
78+
ucc_ee_executor_task_finalize(task->allgather_ring.etask);
79+
}
80+
task->super.status = status;
7381
out:
7482
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_done", 0);
7583
}
@@ -88,22 +96,50 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
8896
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
8997
size_t data_size = (count / tsize) * ucc_dt_size(dt);
9098
ucc_status_t status;
91-
ucc_rank_t block;
99+
ucc_rank_t sendto, recvfrom, sblock, rblock;
100+
ucc_ee_executor_t *exec;
101+
ucc_ee_executor_task_args_t eargs;
102+
void *buf;
92103

93104
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0);
94105
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
95106

107+
sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
108+
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
109+
sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0);
110+
rblock = task->allgather_ring.get_recv_block(&task->subset, trank, tsize, 0);
96111
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
97-
block = task->allgather_ring.get_send_block(&task->subset, trank, tsize,
98-
0);
99-
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block),
100-
sbuf, data_size, rmem, smem);
101-
if (ucc_unlikely(UCC_OK != status)) {
112+
status = ucc_coll_task_get_executor(&task->super, &exec);
113+
if (ucc_unlikely(status != UCC_OK)) {
102114
return status;
103115
}
116+
117+
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
118+
eargs.copy.src = sbuf;
119+
eargs.copy.dst = PTR_OFFSET(rbuf, data_size * sblock);
120+
eargs.copy.len = data_size;
121+
122+
status = ucc_ee_executor_task_post(exec, &eargs,
123+
&task->allgather_ring.etask);
124+
if (ucc_unlikely(status != UCC_OK)) {
125+
return status;
126+
}
127+
buf = sbuf;
128+
} else {
129+
task->allgather_ring.etask = NULL;
130+
buf = PTR_OFFSET(rbuf, data_size * sblock);
104131
}
105132

133+
UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buf, data_size, smem, sendto, team, task),
134+
task, out);
135+
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size),
136+
data_size, rmem, recvfrom, team, task),
137+
task, out);
138+
106139
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
140+
141+
out:
142+
return status;
107143
}
108144

109145
ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
@@ -128,6 +164,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
128164
task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block;
129165
task->super.post = ucc_tl_ucp_allgather_ring_start;
130166
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
167+
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
168+
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
169+
}
131170

132171
return UCC_OK;
133172
}

src/components/tl/ucp/tl_ucp_coll.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ typedef struct ucc_tl_ucp_task {
180180
ucc_rank_t trank,
181181
ucc_rank_t tsize,
182182
int step);
183+
ucc_ee_executor_task_t *etask;
183184
} allgather_ring;
184185
struct {
185186
ucc_rank_t dist;

src/components/tl/ucp/tl_ucp_service_coll.c

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ static ucc_status_t ucc_tl_ucp_service_coll_stop_executor(ucc_coll_task_t *task)
7474
ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf,
7575
void *rbuf, ucc_datatype_t dt,
7676
size_t count, ucc_reduction_op_t op,
77-
ucc_subset_t subset,
77+
ucc_subset_t subset,
7878
ucc_coll_task_t **task_p)
7979
{
8080
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
@@ -140,7 +140,7 @@ ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf,
140140

141141
ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
142142
void *rbuf, size_t msgsize,
143-
ucc_subset_t subset,
143+
ucc_subset_t subset,
144144
ucc_coll_task_t **task_p)
145145
{
146146
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
@@ -178,6 +178,14 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
178178
task->n_polls = npolls;
179179
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
180180
task->super.finalize = ucc_tl_ucp_coll_finalize;
181+
if (in_place) {
182+
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
183+
}
184+
185+
status = ucc_tl_ucp_service_coll_start_executor(&task->super);
186+
if (status != UCC_OK) {
187+
goto free_task;
188+
}
181189

182190
status = ucc_tl_ucp_allgather_ring_start(&task->super);
183191
if (status != UCC_OK) {
@@ -187,15 +195,16 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
187195
*task_p = &task->super;
188196
return status;
189197
finalize_coll:
190-
ucc_tl_ucp_coll_finalize(*task_p);
198+
ucc_tl_ucp_coll_finalize(&task->super);
199+
ucc_tl_ucp_service_coll_stop_executor(&task->super);
191200
free_task:
192201
ucc_tl_ucp_put_task(task);
193202
return status;
194203
}
195204

196205
ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf,
197206
size_t msgsize, ucc_rank_t root,
198-
ucc_subset_t subset,
207+
ucc_subset_t subset,
199208
ucc_coll_task_t **task_p)
200209
{
201210
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
@@ -239,7 +248,8 @@ ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf,
239248
return status;
240249
}
241250

242-
void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id) {
251+
void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id)
252+
{
243253
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
244254

245255
tl_team->super.super.params.id = id;

0 commit comments

Comments
 (0)