Skip to content

Commit 6142008

Browse files
committed
working version on IL1 with Cuda
1 parent 00a9097 commit 6142008

File tree

2 files changed

+44 additions, -30 deletions (lines changed)

2 files changed

+44
-30
lines changed

src/components/tl/ucp/allgather/allgather_knomial.c

Lines changed: 41 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -74,15 +74,15 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
7474
INV_VRANK(peer,broot,size)),
7575
team, task, mh_list[task->allgather_kn.count_mh++]),
7676
task, out);
77-
ucc_assert(task->allgather_kn.count_mh >= max_mh);
77+
ucc_assert(task->allgather_kn.count_mh-1 <= max_mh);
7878

7979
}
8080
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type,
8181
ucc_ep_map_eval(task->subset.map,
8282
INV_VRANK(peer,broot,size)),
8383
team, task, mh_list[task->allgather_kn.count_mh++]),
8484
task, out);
85-
ucc_assert(task->allgather_kn.count_mh >= max_mh);
85+
ucc_assert(task->allgather_kn.count_mh-1 <= max_mh);
8686
}
8787
if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) {
8888
peer = ucc_knomial_pattern_get_extra(p, rank);
@@ -92,7 +92,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
9292
local * dt_size), extra_count * dt_size,
9393
mem_type, peer, team, task, mh_list[task->allgather_kn.count_mh++]),
9494
task, out);
95-
ucc_assert(task->allgather_kn.count_mh >= max_mh);
95+
ucc_assert(task->allgather_kn.count_mh-1 <= max_mh);
9696
}
9797

9898
UCC_KN_PHASE_EXTRA:
@@ -121,14 +121,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
121121
continue;
122122
}
123123
}
124-
printf("progress : count_mh: %d, mh: %lx\n", task->allgather_kn.count_mh, (unsigned long)mh_list[task->allgather_kn.count_mh]);
125124
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count * dt_size,
126125
mem_type,
127126
ucc_ep_map_eval(task->subset.map,
128127
INV_VRANK(peer, broot, size)),
129128
team, task, mh_list[task->allgather_kn.count_mh++]),
130129
task, out);
131-
ucc_assert(task->allgather_kn.count_mh >= max_mh);
130+
ucc_assert(task->allgather_kn.count_mh-1 <= max_mh);
132131
}
133132

134133
for (loop_step = 1; loop_step < radix; loop_step++) {
@@ -152,7 +151,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
152151
INV_VRANK(peer, broot, size)),
153152
team, task, mh_list[task->allgather_kn.count_mh++]),
154153
task, out);
155-
ucc_assert(task->allgather_kn.count_mh >= max_mh);
154+
ucc_assert(task->allgather_kn.count_mh-1 <= max_mh);
156155
}
157156
UCC_KN_PHASE_LOOP:
158157
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks(task)) {
@@ -170,7 +169,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
170169
INV_VRANK(peer, broot, size)),
171170
team, task, mh_list[task->allgather_kn.count_mh++]),
172171
task, out);
173-
ucc_assert(task->allgather_kn.count_mh >= max_mh);
172+
ucc_assert(task->allgather_kn.count_mh-1 <= max_mh);
174173
}
175174
UCC_KN_PHASE_PROXY:
176175
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) {
@@ -252,6 +251,7 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
252251
ucc_tl_ucp_task_t);
253252
ucc_coll_args_t *args = &TASK_ARGS(task);
254253
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
254+
ucc_coll_type_t ct = args->coll_type;
255255
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
256256
uint8_t node_type = task->allgather_kn.p.node_type;
257257
ucc_knomial_pattern_t *p = &task->allgather_kn.p;
@@ -273,18 +273,28 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
273273
ucc_status_t status;
274274
size_t extra_count;
275275

276-
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
277-
ucp_mem_map_params_t mmap_params;
278-
ucp_mem_h mh;
279-
int size_of_list = 1;
280-
int count_mh = 0;
281-
ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h));
276+
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
277+
ucp_mem_map_params_t mmap_params;
278+
// ucp_mem_h mh;
279+
int size_of_list = 1;
280+
int count_mh = 0;
281+
ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h));
282+
283+
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0);
284+
task->allgather_kn.etask = NULL;
285+
task->allgather_kn.phase = UCC_KN_PHASE_INIT;
286+
if (ct == UCC_COLL_TYPE_ALLGATHER) {
287+
ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count,
288+
&task->allgather_kn.p);
289+
} else {
290+
ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count,
291+
&task->allgather_kn.p);
292+
}
282293

283294
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
284295
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
285296
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
286297
mmap_params.memory_type = ucc_memtype_to_ucs[mem_type];
287-
printf("I'm in register memory");
288298
if (KN_NODE_EXTRA == node_type) {
289299
if (p->type != KN_PATTERN_ALLGATHERX) {
290300
mmap_params.address = task->allgather_kn.sbuf;
@@ -310,13 +320,10 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
310320
goto out;
311321
}
312322
while (!ucc_knomial_pattern_loop_done(p)) {
313-
printf("in the while loop");
314323
ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count,
315324
&local_seg_offset);
316325
sbuf = PTR_OFFSET(rbuf, local_seg_offset * dt_size);
317-
318326
for (loop_step = radix - 1; loop_step > 0; loop_step--) {
319-
printf("in the for loop");
320327
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
321328
if (peer == UCC_KN_PEER_NULL)
322329
continue;
@@ -329,7 +336,6 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
329336
}
330337
mmap_params.address = sbuf;
331338
mmap_params.length = local_seg_count * dt_size;
332-
printf("register memory : count_mh: %d, mh: %lx\n", count_mh, (unsigned long)mh_list[count_mh]);
333339
MEM_MAP();
334340
}
335341

@@ -370,12 +376,23 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
370376
ucc_status_t ucc_tl_ucp_allgather_knomial_finalize(ucc_coll_task_t *coll_task){
371377
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
372378
ucc_tl_ucp_task_t);
379+
ucc_status_t status;
380+
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
381+
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
373382

374383
ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1);
384+
for (int i=0; i<task->allgather_kn.max_mh+1; i++){
385+
ucp_mem_unmap(ctx->worker.ucp_context, task->allgather_kn.mh_list[i]);
386+
}
375387
free(task->allgather_kn.mh_list);
388+
status = ucc_tl_ucp_coll_finalize(&task->super);
389+
if (status < 0){
390+
tl_error(UCC_TASK_LIB(task),
391+
"failed to initialize ucc_mpool");
392+
}
376393

377394
return UCC_OK;
378-
};
395+
}
379396

380397
ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
381398
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
@@ -401,17 +418,17 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
401418
task->subset.myrank = sbgp->group_rank;
402419
task->subset.map = sbgp->map;
403420
}
404-
status = register_memory(&task->super);
405-
if (status < 0){
406-
tl_error(UCC_TASK_LIB(task),
407-
"failed to register memory");
408-
}
409421
task->allgather_kn.etask_linked_list_head = NULL;
410422
task->allgather_kn.p.radix = radix;
411423
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
412424
task->super.post = ucc_tl_ucp_allgather_knomial_start;
413425
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
414426
task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize;
427+
status = register_memory(&task->super);
428+
if (status < 0){
429+
tl_error(UCC_TASK_LIB(task),
430+
"failed to register memory");
431+
}
415432
*task_h = &task->super;
416433
return UCC_OK;
417434
}

src/components/tl/ucp/tl_ucp_coll.h

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -58,15 +58,14 @@ void ucc_tl_ucp_team_default_score_str_free(
5858
} while(0)
5959

6060
#define MEM_MAP() do { \
61-
status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); \
61+
status = ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh_list[count_mh++])); \
6262
if (UCC_OK != status) { \
6363
return status; \
6464
} \
6565
if (count_mh == size_of_list){ \
6666
size_of_list *= 2; \
6767
mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); \
6868
} \
69-
mh_list[count_mh++] = mh; \
7069
} while(0)
7170

7271
#define EXEC_TASK_WAIT(_etask, ...) \
@@ -503,7 +502,7 @@ static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *t
503502
while(current_node != NULL) {
504503
status = ucc_ee_executor_task_test(current_node->etask);
505504
if (status > 0) {
506-
ucp_memcpy_device_complete(current_node->etask->completion, status);
505+
ucp_memcpy_device_complete(current_node->etask->completion, ucc_status_to_ucs_status(status));
507506
status_2 = ucc_ee_executor_task_finalize(current_node->etask);
508507
ucc_mpool_put(current_node);
509508
if (ucc_unlikely(status_2 < 0)){
@@ -517,9 +516,7 @@ static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *t
517516
task->allgather_kn.etask_linked_list_head = current_node->next;
518517
}
519518
}
520-
else {
521-
prev_node = current_node;
522-
}
519+
prev_node = current_node;
523520
current_node = current_node->next; //to iterate to next node
524521
}
525522
if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) {

0 commit comments

Comments
 (0)