Skip to content

Commit 00a9097

Browse files
committed
fix all of the comments, still same bug
1 parent 4cbdc7e commit 00a9097

File tree

4 files changed

+89
-48
lines changed

4 files changed

+89
-48
lines changed

src/components/tl/ucp/allgather/allgather_knomial.c

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
5151
args->root : 0;
5252
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
5353
size_t local = GET_LOCAL_COUNT(args, size, rank);
54-
ucp_mem_h *mh_list = task->mh_list;
55-
int max_count = task->count_mh;
56-
int count_mh = 0;
54+
ucp_mem_h *mh_list = task->allgather_kn.mh_list;
55+
int max_mh = task->allgather_kn.max_mh;
5756
void *sbuf;
5857
ptrdiff_t peer_seg_offset, local_seg_offset;
5958
ucc_rank_t peer, peer_dist;
@@ -64,7 +63,6 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
6463

6564
EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test",
6665
task->allgather_kn.etask);
67-
6866
task->allgather_kn.etask = NULL;
6967
UCC_KN_GOTO_PHASE(task->allgather_kn.phase);
7068
if (KN_NODE_EXTRA == node_type) {
@@ -74,27 +72,27 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
7472
local * dt_size, mem_type,
7573
ucc_ep_map_eval(task->subset.map,
7674
INV_VRANK(peer,broot,size)),
77-
team, task, mh_list[count_mh++]),
75+
team, task, mh_list[task->allgather_kn.count_mh++]),
7876
task, out);
79-
ucc_assert(count_mh >= max_count);
77+
ucc_assert(task->allgather_kn.count_mh >= max_mh);
8078

8179
}
8280
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type,
8381
ucc_ep_map_eval(task->subset.map,
8482
INV_VRANK(peer,broot,size)),
85-
team, task, mh_list[count_mh++]),
83+
team, task, mh_list[task->allgather_kn.count_mh++]),
8684
task, out);
87-
ucc_assert(count_mh >= max_count);
85+
ucc_assert(task->allgather_kn.count_mh >= max_mh);
8886
}
8987
if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) {
9088
peer = ucc_knomial_pattern_get_extra(p, rank);
9189
extra_count = GET_LOCAL_COUNT(args, size, peer);
9290
peer = ucc_ep_map_eval(task->subset.map, peer);
9391
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(task->allgather_kn.sbuf,
9492
local * dt_size), extra_count * dt_size,
95-
mem_type, peer, team, task, mh_list[count_mh++]),
93+
mem_type, peer, team, task, mh_list[task->allgather_kn.count_mh++]),
9694
task, out);
97-
ucc_assert(count_mh >= max_count);
95+
ucc_assert(task->allgather_kn.count_mh >= max_mh);
9896
}
9997

10098
UCC_KN_PHASE_EXTRA:
@@ -123,13 +121,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
123121
continue;
124122
}
125123
}
124+
printf("progress : count_mh: %d, mh: %lx\n", task->allgather_kn.count_mh, (unsigned long)mh_list[task->allgather_kn.count_mh]);
126125
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count * dt_size,
127126
mem_type,
128127
ucc_ep_map_eval(task->subset.map,
129128
INV_VRANK(peer, broot, size)),
130-
team, task, mh_list[count_mh++]),
129+
team, task, mh_list[task->allgather_kn.count_mh++]),
131130
task, out);
132-
ucc_assert(count_mh >= max_count);
131+
ucc_assert(task->allgather_kn.count_mh >= max_mh);
133132
}
134133

135134
for (loop_step = 1; loop_step < radix; loop_step++) {
@@ -151,9 +150,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
151150
peer_seg_count * dt_size, mem_type,
152151
ucc_ep_map_eval(task->subset.map,
153152
INV_VRANK(peer, broot, size)),
154-
team, task, mh_list[count_mh++]),
153+
team, task, mh_list[task->allgather_kn.count_mh++]),
155154
task, out);
156-
ucc_assert(count_mh >= max_count);
155+
ucc_assert(task->allgather_kn.count_mh >= max_mh);
157156
}
158157
UCC_KN_PHASE_LOOP:
159158
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks(task)) {
@@ -169,9 +168,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
169168
mem_type,
170169
ucc_ep_map_eval(task->subset.map,
171170
INV_VRANK(peer, broot, size)),
172-
team, task, mh_list[count_mh++]),
171+
team, task, mh_list[task->allgather_kn.count_mh++]),
173172
task, out);
174-
ucc_assert(count_mh >= max_count);
173+
ucc_assert(task->allgather_kn.count_mh >= max_mh);
175174
}
176175
UCC_KN_PHASE_PROXY:
177176
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) {
@@ -180,6 +179,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
180179
}
181180

182181
out:
182+
ucc_assert(task->allgather_kn.count_mh-1 == max_mh);
183183
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
184184
task->super.status = UCC_OK;
185185
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0);
@@ -205,6 +205,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
205205

206206
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0);
207207
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
208+
task->allgather_kn.etask = NULL;
208209
task->allgather_kn.phase = UCC_KN_PHASE_INIT;
209210
if (ct == UCC_COLL_TYPE_ALLGATHER) {
210211
ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count,
@@ -245,7 +246,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
245246
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
246247
}
247248

248-
void register_memory(ucc_coll_task_t *coll_task){
249+
ucc_status_t register_memory(ucc_coll_task_t *coll_task){
249250

250251
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
251252
ucc_tl_ucp_task_t);
@@ -283,10 +284,9 @@ void register_memory(ucc_coll_task_t *coll_task){
283284
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
284285
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
285286
mmap_params.memory_type = ucc_memtype_to_ucs[mem_type];
286-
287+
printf("I'm in register memory");
287288
if (KN_NODE_EXTRA == node_type) {
288289
if (p->type != KN_PATTERN_ALLGATHERX) {
289-
290290
mmap_params.address = task->allgather_kn.sbuf;
291291
mmap_params.length = local * dt_size;
292292
MEM_MAP();
@@ -310,11 +310,13 @@ void register_memory(ucc_coll_task_t *coll_task){
310310
goto out;
311311
}
312312
while (!ucc_knomial_pattern_loop_done(p)) {
313+
printf("in the while loop");
313314
ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count,
314315
&local_seg_offset);
315316
sbuf = PTR_OFFSET(rbuf, local_seg_offset * dt_size);
316317

317318
for (loop_step = radix - 1; loop_step > 0; loop_step--) {
319+
printf("in the for loop");
318320
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
319321
if (peer == UCC_KN_PEER_NULL)
320322
continue;
@@ -327,6 +329,7 @@ void register_memory(ucc_coll_task_t *coll_task){
327329
}
328330
mmap_params.address = sbuf;
329331
mmap_params.length = local_seg_count * dt_size;
332+
printf("register memory : count_mh: %d, mh: %lx\n", count_mh, (unsigned long)mh_list[count_mh]);
330333
MEM_MAP();
331334
}
332335

@@ -358,35 +361,57 @@ void register_memory(ucc_coll_task_t *coll_task){
358361
}
359362

360363
out:
361-
task->mh_list = mh_list;
362-
task->count_mh = count_mh-1;
364+
task->allgather_kn.mh_list = mh_list;
365+
task->allgather_kn.max_mh = count_mh-1;
366+
task->allgather_kn.count_mh = 0;
367+
return UCC_OK;
363368
}
364369

370+
ucc_status_t ucc_tl_ucp_allgather_knomial_finalize(ucc_coll_task_t *coll_task){
371+
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
372+
ucc_tl_ucp_task_t);
373+
374+
ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1);
375+
free(task->allgather_kn.mh_list);
376+
377+
return UCC_OK;
378+
};
379+
365380
ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
366381
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
367382
ucc_coll_task_t **task_h, ucc_kn_radix_t radix)
368383
{
369384
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
370385
ucc_tl_ucp_task_t *task;
371386
ucc_sbgp_t *sbgp;
387+
ucc_status_t status;
372388

373389
task = ucc_tl_ucp_init_task(coll_args, team);
374-
ucc_mpool_init(&task->allgather_kn.etask_node_mpool, 0, sizeof(node_ucc_ee_executor_task_t),
390+
status = ucc_mpool_init(&task->allgather_kn.etask_node_mpool, 0, sizeof(node_ucc_ee_executor_task_t),
375391
0, UCC_CACHE_LINE_SIZE, 16, UINT_MAX, NULL,
376392
tl_team->super.super.context->ucc_context->thread_mode, "etasks_linked_list_nodes");
393+
if (status < 0){
394+
tl_error(UCC_TASK_LIB(task),
395+
"failed to initialize ucc_mpool");
396+
}
377397

378398
if (tl_team->cfg.use_reordering &&
379399
coll_args->args.coll_type == UCC_COLL_TYPE_ALLREDUCE) {
380400
sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED);
381401
task->subset.myrank = sbgp->group_rank;
382402
task->subset.map = sbgp->map;
383403
}
384-
register_memory(&task->super);
404+
status = register_memory(&task->super);
405+
if (status < 0){
406+
tl_error(UCC_TASK_LIB(task),
407+
"failed to register memory");
408+
}
385409
task->allgather_kn.etask_linked_list_head = NULL;
386410
task->allgather_kn.p.radix = radix;
387411
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
388412
task->super.post = ucc_tl_ucp_allgather_knomial_start;
389413
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
414+
task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize;
390415
*task_h = &task->super;
391416
return UCC_OK;
392417
}

src/components/tl/ucp/tl_ucp_coll.h

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ void ucc_tl_ucp_team_default_score_str_free(
6060
#define MEM_MAP() do { \
6161
status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); \
6262
if (UCC_OK != status) { \
63-
task->super.status = status; \
64-
return; \
63+
return status; \
6564
} \
6665
if (count_mh == size_of_list){ \
6766
size_of_list *= 2; \
@@ -109,8 +108,6 @@ typedef struct ucc_tl_ucp_allreduce_sw_host_allgather
109108
typedef struct ucc_tl_ucp_task {
110109
ucc_coll_task_t super;
111110
uint32_t flags;
112-
ucp_mem_h *mh_list;
113-
int count_mh;
114111
union {
115112
struct {
116113
uint32_t send_posted;
@@ -197,10 +194,13 @@ typedef struct ucc_tl_ucp_task {
197194
int phase;
198195
ucc_knomial_pattern_t p;
199196
void *sbuf;
197+
ucc_ee_executor_task_t *etask;
200198
node_ucc_ee_executor_task_t *etask_linked_list_head;
201199
ucc_rank_t recv_dist;
202200
ucc_mpool_t etask_node_mpool;
203-
ucc_ee_executor_task_t *etask;
201+
ucp_mem_h *mh_list;
202+
int count_mh;
203+
int max_mh;
204204
} allgather_kn;
205205
struct {
206206
/*
@@ -427,27 +427,34 @@ static inline ucc_status_t ucc_tl_ucp_test_with_etasks(ucc_tl_ucp_task_t *task)
427427
{
428428
int polls = 0;
429429
ucc_status_t status;
430+
ucc_status_t status_2;
431+
node_ucc_ee_executor_task_t *current_node;
432+
node_ucc_ee_executor_task_t *prev_node;
430433

434+
if (UCC_TL_UCP_TASK_P2P_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) {
435+
return UCC_OK;
436+
}
431437
while (polls++ < task->n_polls) {
432-
node_ucc_ee_executor_task_t *current_node;
433-
node_ucc_ee_executor_task_t *prev_node;
434438
current_node = task->allgather_kn.etask_linked_list_head;
435439
prev_node = NULL;
436440
while(current_node != NULL) {
437441
status = ucc_ee_executor_task_test(current_node->etask);
438442
if (status > 0) {
439-
ucp_memcpy_device_complete(current_node->etask->completion, status);
440-
ucc_ee_executor_task_finalize(current_node->etask);
443+
ucp_memcpy_device_complete(current_node->etask->completion, ucc_status_to_ucs_status(status));
444+
status_2 = ucc_ee_executor_task_finalize(current_node->etask);
445+
ucc_mpool_put(current_node);
446+
if (ucc_unlikely(status_2 < 0)){
447+
tl_error(UCC_TASK_LIB(task), "task finalize didnt work");
448+
return status_2;
449+
}
441450
if (prev_node != NULL){
442451
prev_node->next = current_node->next; //to remove from list
443452
}
444453
else{ //i'm on first node
445454
task->allgather_kn.etask_linked_list_head = current_node->next;
446455
}
447-
}
448-
else {
449-
prev_node = current_node;
450456
}
457+
prev_node = current_node;
451458
current_node = current_node->next; //to iterate to next node
452459
}
453460
if (UCC_TL_UCP_TASK_P2P_COMPLETE(task) && task->allgather_kn.etask_linked_list_head == NULL) {
@@ -483,17 +490,26 @@ static inline ucc_status_t ucc_tl_ucp_test_recv(ucc_tl_ucp_task_t *task)
483490
static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *task) {
484491
int polls = 0;
485492
ucc_status_t status;
493+
ucc_status_t status_2;
494+
node_ucc_ee_executor_task_t *current_node;
495+
node_ucc_ee_executor_task_t *prev_node;
486496

497+
if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) {
498+
return UCC_OK;
499+
}
487500
while (polls++ < task->n_polls) {
488-
node_ucc_ee_executor_task_t *current_node;
489-
node_ucc_ee_executor_task_t *prev_node;
490501
current_node = task->allgather_kn.etask_linked_list_head;
491502
prev_node = NULL;
492503
while(current_node != NULL) {
493-
status = ucc_ee_executor_task_test(current_node->etask); \
504+
status = ucc_ee_executor_task_test(current_node->etask);
494505
if (status > 0) {
495-
ucp_memcpy_device_complete(current_node->etask->completion, status); \
496-
ucc_ee_executor_task_finalize(current_node->etask); \
506+
ucp_memcpy_device_complete(current_node->etask->completion, status);
507+
status_2 = ucc_ee_executor_task_finalize(current_node->etask);
508+
ucc_mpool_put(current_node);
509+
if (ucc_unlikely(status_2 < 0)){
510+
tl_error(UCC_TASK_LIB(task), "task finalize didnt work");
511+
return status_2;
512+
}
497513
if (prev_node != NULL){
498514
prev_node->next = current_node->next; //to remove from list
499515
}

src/components/tl/ucp/tl_ucp_context.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "utils/arch/cpu.h"
1414
#include "schedule/ucc_schedule_pipelined.h"
1515
#include <limits.h>
16-
#include <ucp/api/ucp.h>
1716

1817
#define UCP_CHECK(function, msg, go, ctx) \
1918
status = function; \
@@ -144,7 +143,6 @@ static int memcpy_device_start(void *dest, void *src, size_t size,
144143

145144
status = ucc_coll_task_get_executor(&task->super, &exec);
146145
if (ucc_unlikely(status != UCC_OK)) {
147-
task->super.status = status;
148146
return status;
149147
}
150148

@@ -154,8 +152,12 @@ static int memcpy_device_start(void *dest, void *src, size_t size,
154152
eargs.copy.len = size;
155153
node_ucc_ee_executor_task_t *new_node;
156154
new_node = ucc_mpool_get(&task->allgather_kn.etask_node_mpool);
155+
if (ucc_unlikely(!new_node)) {
156+
return UCC_ERR_NO_MEMORY;
157+
}
157158
status = ucc_ee_executor_task_post(exec, &eargs,
158159
&new_node->etask);
160+
task->allgather_kn.etask_linked_list_head->etask->completion = completion;
159161

160162
if (ucc_unlikely(status != UCC_OK)) {
161163
task->super.status = status;
@@ -164,7 +166,6 @@ static int memcpy_device_start(void *dest, void *src, size_t size,
164166
new_node->next = task->allgather_kn.etask_linked_list_head;
165167
task->allgather_kn.etask_linked_list_head = new_node;
166168

167-
task->allgather_kn.etask_linked_list_head->etask->completion = completion;
168169
return 1;
169170

170171
}
@@ -179,7 +180,6 @@ static int memcpy_device(void *dest, void *src, size_t size, void *user_data){
179180

180181
status = ucc_coll_task_get_executor(&task->super, &exec);
181182
if (ucc_unlikely(status != UCC_OK)) {
182-
task->super.status = status;
183183
return status;
184184
}
185185

@@ -190,18 +190,19 @@ static int memcpy_device(void *dest, void *src, size_t size, void *user_data){
190190

191191
status = ucc_ee_executor_task_post(exec, &eargs, &etask);
192192
if (ucc_unlikely(status < 0)) {
193-
task->super.status = status;
194193
return status;
195194
}
196195
status = ucc_ee_executor_task_test(etask);
197196
while (status>0) {
198197
status = ucc_ee_executor_task_test(etask);
199198
if (ucc_unlikely(status < 0)) {
200-
task->super.status = status;
201199
return status;
202200
}
203201
}
204-
ucc_ee_executor_task_finalize(etask);
202+
status = ucc_ee_executor_task_finalize(etask);
203+
if (ucc_unlikely(status < 0)) {
204+
return status;
205+
}
205206
return 1;
206207
}
207208

src/schedule/ucc_schedule.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#ifndef UCC_SCHEDULE_H_
88
#define UCC_SCHEDULE_H_
99

10-
#include <ucp/api/ucp.h>
1110
#include "ucc/api/ucc.h"
1211
#include "utils/ucc_list.h"
1312
#include "utils/ucc_log.h"

0 commit comments

Comments
 (0)