@@ -74,15 +74,15 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
74
74
INV_VRANK (peer ,broot ,size )),
75
75
team , task , mh_list [task -> allgather_kn .count_mh ++ ]),
76
76
task , out );
77
- ucc_assert (task -> allgather_kn .count_mh > = max_mh );
77
+ ucc_assert (task -> allgather_kn .count_mh - 1 < = max_mh );
78
78
79
79
}
80
80
UCPCHECK_GOTO (ucc_tl_ucp_send_nb_with_mem (rbuf , data_size , mem_type ,
81
81
ucc_ep_map_eval (task -> subset .map ,
82
82
INV_VRANK (peer ,broot ,size )),
83
83
team , task , mh_list [task -> allgather_kn .count_mh ++ ]),
84
84
task , out );
85
- ucc_assert (task -> allgather_kn .count_mh > = max_mh );
85
+ ucc_assert (task -> allgather_kn .count_mh - 1 < = max_mh );
86
86
}
87
87
if ((p -> type != KN_PATTERN_ALLGATHERX ) && (node_type == KN_NODE_PROXY )) {
88
88
peer = ucc_knomial_pattern_get_extra (p , rank );
@@ -92,7 +92,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
92
92
local * dt_size ), extra_count * dt_size ,
93
93
mem_type , peer , team , task , mh_list [task -> allgather_kn .count_mh ++ ]),
94
94
task , out );
95
- ucc_assert (task -> allgather_kn .count_mh > = max_mh );
95
+ ucc_assert (task -> allgather_kn .count_mh - 1 < = max_mh );
96
96
}
97
97
98
98
UCC_KN_PHASE_EXTRA :
@@ -121,14 +121,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
121
121
continue ;
122
122
}
123
123
}
124
- printf ("progress : count_mh: %d, mh: %lx\n" , task -> allgather_kn .count_mh , (unsigned long )mh_list [task -> allgather_kn .count_mh ]);
125
124
UCPCHECK_GOTO (ucc_tl_ucp_send_nb_with_mem (sbuf , local_seg_count * dt_size ,
126
125
mem_type ,
127
126
ucc_ep_map_eval (task -> subset .map ,
128
127
INV_VRANK (peer , broot , size )),
129
128
team , task , mh_list [task -> allgather_kn .count_mh ++ ]),
130
129
task , out );
131
- ucc_assert (task -> allgather_kn .count_mh > = max_mh );
130
+ ucc_assert (task -> allgather_kn .count_mh - 1 < = max_mh );
132
131
}
133
132
134
133
for (loop_step = 1 ; loop_step < radix ; loop_step ++ ) {
@@ -152,7 +151,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
152
151
INV_VRANK (peer , broot , size )),
153
152
team , task , mh_list [task -> allgather_kn .count_mh ++ ]),
154
153
task , out );
155
- ucc_assert (task -> allgather_kn .count_mh > = max_mh );
154
+ ucc_assert (task -> allgather_kn .count_mh - 1 < = max_mh );
156
155
}
157
156
UCC_KN_PHASE_LOOP :
158
157
if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks (task )) {
@@ -170,7 +169,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
170
169
INV_VRANK (peer , broot , size )),
171
170
team , task , mh_list [task -> allgather_kn .count_mh ++ ]),
172
171
task , out );
173
- ucc_assert (task -> allgather_kn .count_mh > = max_mh );
172
+ ucc_assert (task -> allgather_kn .count_mh - 1 < = max_mh );
174
173
}
175
174
UCC_KN_PHASE_PROXY :
176
175
if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks (task )) {
@@ -252,6 +251,7 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
252
251
ucc_tl_ucp_task_t );
253
252
ucc_coll_args_t * args = & TASK_ARGS (task );
254
253
ucc_tl_ucp_team_t * team = TASK_TEAM (task );
254
+ ucc_coll_type_t ct = args -> coll_type ;
255
255
ucc_kn_radix_t radix = task -> allgather_kn .p .radix ;
256
256
uint8_t node_type = task -> allgather_kn .p .node_type ;
257
257
ucc_knomial_pattern_t * p = & task -> allgather_kn .p ;
@@ -273,18 +273,28 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
273
273
ucc_status_t status ;
274
274
size_t extra_count ;
275
275
276
- ucc_tl_ucp_context_t * ctx = UCC_TL_UCP_TEAM_CTX (team );
277
- ucp_mem_map_params_t mmap_params ;
278
- ucp_mem_h mh ;
279
- int size_of_list = 1 ;
280
- int count_mh = 0 ;
281
- ucp_mem_h * mh_list = (ucp_mem_h * )malloc (size_of_list * sizeof (ucp_mem_h ));
276
+ ucc_tl_ucp_context_t * ctx = UCC_TL_UCP_TEAM_CTX (team );
277
+ ucp_mem_map_params_t mmap_params ;
278
+ // ucp_mem_h mh;
279
+ int size_of_list = 1 ;
280
+ int count_mh = 0 ;
281
+ ucp_mem_h * mh_list = (ucp_mem_h * )malloc (size_of_list * sizeof (ucp_mem_h ));
282
+
283
+ UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgather_kn_start" , 0 );
284
+ task -> allgather_kn .etask = NULL ;
285
+ task -> allgather_kn .phase = UCC_KN_PHASE_INIT ;
286
+ if (ct == UCC_COLL_TYPE_ALLGATHER ) {
287
+ ucc_kn_ag_pattern_init (size , rank , radix , args -> dst .info .count ,
288
+ & task -> allgather_kn .p );
289
+ } else {
290
+ ucc_kn_agx_pattern_init (size , rank , radix , args -> dst .info .count ,
291
+ & task -> allgather_kn .p );
292
+ }
282
293
283
294
mmap_params .field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
284
295
UCP_MEM_MAP_PARAM_FIELD_LENGTH |
285
296
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE ;
286
297
mmap_params .memory_type = ucc_memtype_to_ucs [mem_type ];
287
- printf ("I'm in register memory" );
288
298
if (KN_NODE_EXTRA == node_type ) {
289
299
if (p -> type != KN_PATTERN_ALLGATHERX ) {
290
300
mmap_params .address = task -> allgather_kn .sbuf ;
@@ -310,13 +320,10 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
310
320
goto out ;
311
321
}
312
322
while (!ucc_knomial_pattern_loop_done (p )) {
313
- printf ("in the while loop" );
314
323
ucc_kn_ag_pattern_peer_seg (rank , p , & local_seg_count ,
315
324
& local_seg_offset );
316
325
sbuf = PTR_OFFSET (rbuf , local_seg_offset * dt_size );
317
-
318
326
for (loop_step = radix - 1 ; loop_step > 0 ; loop_step -- ) {
319
- printf ("in the for loop" );
320
327
peer = ucc_knomial_pattern_get_loop_peer (p , rank , loop_step );
321
328
if (peer == UCC_KN_PEER_NULL )
322
329
continue ;
@@ -329,7 +336,6 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
329
336
}
330
337
mmap_params .address = sbuf ;
331
338
mmap_params .length = local_seg_count * dt_size ;
332
- printf ("register memory : count_mh: %d, mh: %lx\n" , count_mh , (unsigned long )mh_list [count_mh ]);
333
339
MEM_MAP ();
334
340
}
335
341
@@ -370,12 +376,23 @@ ucc_status_t register_memory(ucc_coll_task_t *coll_task){
370
376
ucc_status_t ucc_tl_ucp_allgather_knomial_finalize (ucc_coll_task_t * coll_task ){
371
377
ucc_tl_ucp_task_t * task = ucc_derived_of (coll_task ,
372
378
ucc_tl_ucp_task_t );
379
+ ucc_status_t status ;
380
+ ucc_tl_ucp_team_t * team = TASK_TEAM (task );
381
+ ucc_tl_ucp_context_t * ctx = UCC_TL_UCP_TEAM_CTX (team );
373
382
374
383
ucc_mpool_cleanup (& task -> allgather_kn .etask_node_mpool , 1 );
384
+ for (int i = 0 ; i < task -> allgather_kn .max_mh + 1 ; i ++ ){
385
+ ucp_mem_unmap (ctx -> worker .ucp_context , task -> allgather_kn .mh_list [i ]);
386
+ }
375
387
free (task -> allgather_kn .mh_list );
388
+ status = ucc_tl_ucp_coll_finalize (& task -> super );
389
+ if (status < 0 ){
390
+ tl_error (UCC_TASK_LIB (task ),
391
+ "failed to initialize ucc_mpool" );
392
+ }
376
393
377
394
return UCC_OK ;
378
- };
395
+ }
379
396
380
397
ucc_status_t ucc_tl_ucp_allgather_knomial_init_r (
381
398
ucc_base_coll_args_t * coll_args , ucc_base_team_t * team ,
@@ -401,17 +418,17 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
401
418
task -> subset .myrank = sbgp -> group_rank ;
402
419
task -> subset .map = sbgp -> map ;
403
420
}
404
- status = register_memory (& task -> super );
405
- if (status < 0 ){
406
- tl_error (UCC_TASK_LIB (task ),
407
- "failed to register memory" );
408
- }
409
421
task -> allgather_kn .etask_linked_list_head = NULL ;
410
422
task -> allgather_kn .p .radix = radix ;
411
423
task -> super .flags |= UCC_COLL_TASK_FLAG_EXECUTOR ;
412
424
task -> super .post = ucc_tl_ucp_allgather_knomial_start ;
413
425
task -> super .progress = ucc_tl_ucp_allgather_knomial_progress ;
414
426
task -> super .finalize = ucc_tl_ucp_allgather_knomial_finalize ;
427
+ status = register_memory (& task -> super );
428
+ if (status < 0 ){
429
+ tl_error (UCC_TASK_LIB (task ),
430
+ "failed to register memory" );
431
+ }
415
432
* task_h = & task -> super ;
416
433
return UCC_OK ;
417
434
}
0 commit comments