@@ -51,9 +51,8 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
51
51
args -> root : 0 ;
52
52
ucc_rank_t rank = VRANK (task -> subset .myrank , broot , size );
53
53
size_t local = GET_LOCAL_COUNT (args , size , rank );
54
- ucp_mem_h * mh_list = task -> mh_list ;
55
- int max_count = task -> count_mh ;
56
- int count_mh = 0 ;
54
+ ucp_mem_h * mh_list = task -> allgather_kn .mh_list ;
55
+ int max_mh = task -> allgather_kn .max_mh ;
57
56
void * sbuf ;
58
57
ptrdiff_t peer_seg_offset , local_seg_offset ;
59
58
ucc_rank_t peer , peer_dist ;
@@ -64,7 +63,6 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
64
63
65
64
EXEC_TASK_TEST (UCC_KN_PHASE_INIT , "failed during ee task test" ,
66
65
task -> allgather_kn .etask );
67
-
68
66
task -> allgather_kn .etask = NULL ;
69
67
UCC_KN_GOTO_PHASE (task -> allgather_kn .phase );
70
68
if (KN_NODE_EXTRA == node_type ) {
@@ -74,27 +72,27 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
74
72
local * dt_size , mem_type ,
75
73
ucc_ep_map_eval (task -> subset .map ,
76
74
INV_VRANK (peer ,broot ,size )),
77
- team , task , mh_list [count_mh ++ ]),
75
+ team , task , mh_list [task -> allgather_kn . count_mh ++ ]),
78
76
task , out );
79
- ucc_assert (count_mh >= max_count );
77
+ ucc_assert (task -> allgather_kn . count_mh >= max_mh );
80
78
81
79
}
82
80
UCPCHECK_GOTO (ucc_tl_ucp_send_nb_with_mem (rbuf , data_size , mem_type ,
83
81
ucc_ep_map_eval (task -> subset .map ,
84
82
INV_VRANK (peer ,broot ,size )),
85
- team , task , mh_list [count_mh ++ ]),
83
+ team , task , mh_list [task -> allgather_kn . count_mh ++ ]),
86
84
task , out );
87
- ucc_assert (count_mh >= max_count );
85
+ ucc_assert (task -> allgather_kn . count_mh >= max_mh );
88
86
}
89
87
if ((p -> type != KN_PATTERN_ALLGATHERX ) && (node_type == KN_NODE_PROXY )) {
90
88
peer = ucc_knomial_pattern_get_extra (p , rank );
91
89
extra_count = GET_LOCAL_COUNT (args , size , peer );
92
90
peer = ucc_ep_map_eval (task -> subset .map , peer );
93
91
UCPCHECK_GOTO (ucc_tl_ucp_recv_nb_with_mem (PTR_OFFSET (task -> allgather_kn .sbuf ,
94
92
local * dt_size ), extra_count * dt_size ,
95
- mem_type , peer , team , task , mh_list [count_mh ++ ]),
93
+ mem_type , peer , team , task , mh_list [task -> allgather_kn . count_mh ++ ]),
96
94
task , out );
97
- ucc_assert (count_mh >= max_count );
95
+ ucc_assert (task -> allgather_kn . count_mh >= max_mh );
98
96
}
99
97
100
98
UCC_KN_PHASE_EXTRA :
@@ -123,13 +121,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
123
121
continue ;
124
122
}
125
123
}
124
+ printf ("progress : count_mh: %d, mh: %lx\n" , task -> allgather_kn .count_mh , (unsigned long )mh_list [task -> allgather_kn .count_mh ]);
126
125
UCPCHECK_GOTO (ucc_tl_ucp_send_nb_with_mem (sbuf , local_seg_count * dt_size ,
127
126
mem_type ,
128
127
ucc_ep_map_eval (task -> subset .map ,
129
128
INV_VRANK (peer , broot , size )),
130
- team , task , mh_list [count_mh ++ ]),
129
+ team , task , mh_list [task -> allgather_kn . count_mh ++ ]),
131
130
task , out );
132
- ucc_assert (count_mh >= max_count );
131
+ ucc_assert (task -> allgather_kn . count_mh >= max_mh );
133
132
}
134
133
135
134
for (loop_step = 1 ; loop_step < radix ; loop_step ++ ) {
@@ -151,9 +150,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
151
150
peer_seg_count * dt_size , mem_type ,
152
151
ucc_ep_map_eval (task -> subset .map ,
153
152
INV_VRANK (peer , broot , size )),
154
- team , task , mh_list [count_mh ++ ]),
153
+ team , task , mh_list [task -> allgather_kn . count_mh ++ ]),
155
154
task , out );
156
- ucc_assert (count_mh >= max_count );
155
+ ucc_assert (task -> allgather_kn . count_mh >= max_mh );
157
156
}
158
157
UCC_KN_PHASE_LOOP :
159
158
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks (task )) {
@@ -169,9 +168,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
169
168
mem_type ,
170
169
ucc_ep_map_eval (task -> subset .map ,
171
170
INV_VRANK (peer , broot , size )),
172
- team , task , mh_list [count_mh ++ ]),
171
+ team , task , mh_list [task -> allgather_kn . count_mh ++ ]),
173
172
task , out );
174
- ucc_assert (count_mh >= max_count );
173
+ ucc_assert (task -> allgather_kn . count_mh >= max_mh );
175
174
}
176
175
UCC_KN_PHASE_PROXY :
177
176
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks (task )) {
@@ -180,6 +179,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
180
179
}
181
180
182
181
out :
182
+ ucc_assert (task -> allgather_kn .count_mh - 1 == max_mh );
183
183
ucc_assert (UCC_TL_UCP_TASK_P2P_COMPLETE (task ));
184
184
task -> super .status = UCC_OK ;
185
185
UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgather_kn_done" , 0 );
@@ -205,6 +205,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
205
205
206
206
UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgather_kn_start" , 0 );
207
207
ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
208
+ task -> allgather_kn .etask = NULL ;
208
209
task -> allgather_kn .phase = UCC_KN_PHASE_INIT ;
209
210
if (ct == UCC_COLL_TYPE_ALLGATHER ) {
210
211
ucc_kn_ag_pattern_init (size , rank , radix , args -> dst .info .count ,
@@ -245,7 +246,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
245
246
return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
246
247
}
247
248
248
- void register_memory (ucc_coll_task_t * coll_task ){
249
+ ucc_status_t register_memory (ucc_coll_task_t * coll_task ){
249
250
250
251
ucc_tl_ucp_task_t * task = ucc_derived_of (coll_task ,
251
252
ucc_tl_ucp_task_t );
@@ -283,10 +284,9 @@ void register_memory(ucc_coll_task_t *coll_task){
283
284
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
284
285
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE ;
285
286
mmap_params .memory_type = ucc_memtype_to_ucs [mem_type ];
286
-
287
+ printf ( "I'm in register memory" );
287
288
if (KN_NODE_EXTRA == node_type ) {
288
289
if (p -> type != KN_PATTERN_ALLGATHERX ) {
289
-
290
290
mmap_params .address = task -> allgather_kn .sbuf ;
291
291
mmap_params .length = local * dt_size ;
292
292
MEM_MAP ();
@@ -310,11 +310,13 @@ void register_memory(ucc_coll_task_t *coll_task){
310
310
goto out ;
311
311
}
312
312
while (!ucc_knomial_pattern_loop_done (p )) {
313
+ printf ("in the while loop" );
313
314
ucc_kn_ag_pattern_peer_seg (rank , p , & local_seg_count ,
314
315
& local_seg_offset );
315
316
sbuf = PTR_OFFSET (rbuf , local_seg_offset * dt_size );
316
317
317
318
for (loop_step = radix - 1 ; loop_step > 0 ; loop_step -- ) {
319
+ printf ("in the for loop" );
318
320
peer = ucc_knomial_pattern_get_loop_peer (p , rank , loop_step );
319
321
if (peer == UCC_KN_PEER_NULL )
320
322
continue ;
@@ -327,6 +329,7 @@ void register_memory(ucc_coll_task_t *coll_task){
327
329
}
328
330
mmap_params .address = sbuf ;
329
331
mmap_params .length = local_seg_count * dt_size ;
332
+ printf ("register memory : count_mh: %d, mh: %lx\n" , count_mh , (unsigned long )mh_list [count_mh ]);
330
333
MEM_MAP ();
331
334
}
332
335
@@ -358,35 +361,57 @@ void register_memory(ucc_coll_task_t *coll_task){
358
361
}
359
362
360
363
out :
361
- task -> mh_list = mh_list ;
362
- task -> count_mh = count_mh - 1 ;
364
+ task -> allgather_kn .mh_list = mh_list ;
365
+ task -> allgather_kn .max_mh = count_mh - 1 ;
366
+ task -> allgather_kn .count_mh = 0 ;
367
+ return UCC_OK ;
363
368
}
364
369
370
+ ucc_status_t ucc_tl_ucp_allgather_knomial_finalize (ucc_coll_task_t * coll_task ){
371
+ ucc_tl_ucp_task_t * task = ucc_derived_of (coll_task ,
372
+ ucc_tl_ucp_task_t );
373
+
374
+ ucc_mpool_cleanup (& task -> allgather_kn .etask_node_mpool , 1 );
375
+ free (task -> allgather_kn .mh_list );
376
+
377
+ return UCC_OK ;
378
+ };
379
+
365
380
ucc_status_t ucc_tl_ucp_allgather_knomial_init_r (
366
381
ucc_base_coll_args_t * coll_args , ucc_base_team_t * team ,
367
382
ucc_coll_task_t * * task_h , ucc_kn_radix_t radix )
368
383
{
369
384
ucc_tl_ucp_team_t * tl_team = ucc_derived_of (team , ucc_tl_ucp_team_t );
370
385
ucc_tl_ucp_task_t * task ;
371
386
ucc_sbgp_t * sbgp ;
387
+ ucc_status_t status ;
372
388
373
389
task = ucc_tl_ucp_init_task (coll_args , team );
374
- ucc_mpool_init (& task -> allgather_kn .etask_node_mpool , 0 , sizeof (node_ucc_ee_executor_task_t ),
390
+ status = ucc_mpool_init (& task -> allgather_kn .etask_node_mpool , 0 , sizeof (node_ucc_ee_executor_task_t ),
375
391
0 , UCC_CACHE_LINE_SIZE , 16 , UINT_MAX , NULL ,
376
392
tl_team -> super .super .context -> ucc_context -> thread_mode , "etasks_linked_list_nodes" );
393
+ if (status < 0 ){
394
+ tl_error (UCC_TASK_LIB (task ),
395
+ "failed to initialize ucc_mpool" );
396
+ }
377
397
378
398
if (tl_team -> cfg .use_reordering &&
379
399
coll_args -> args .coll_type == UCC_COLL_TYPE_ALLREDUCE ) {
380
400
sbgp = ucc_topo_get_sbgp (tl_team -> topo , UCC_SBGP_FULL_HOST_ORDERED );
381
401
task -> subset .myrank = sbgp -> group_rank ;
382
402
task -> subset .map = sbgp -> map ;
383
403
}
384
- register_memory (& task -> super );
404
+ status = register_memory (& task -> super );
405
+ if (status < 0 ){
406
+ tl_error (UCC_TASK_LIB (task ),
407
+ "failed to register memory" );
408
+ }
385
409
task -> allgather_kn .etask_linked_list_head = NULL ;
386
410
task -> allgather_kn .p .radix = radix ;
387
411
task -> super .flags |= UCC_COLL_TASK_FLAG_EXECUTOR ;
388
412
task -> super .post = ucc_tl_ucp_allgather_knomial_start ;
389
413
task -> super .progress = ucc_tl_ucp_allgather_knomial_progress ;
414
+ task -> super .finalize = ucc_tl_ucp_allgather_knomial_finalize ;
390
415
* task_h = & task -> super ;
391
416
return UCC_OK ;
392
417
}
0 commit comments