@@ -40,6 +40,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
40
40
size_t count = TASK_ARGS (task ).dst .info .count ;
41
41
ucc_datatype_t dt = TASK_ARGS (task ).dst .info .datatype ;
42
42
size_t data_size = (count / tsize ) * ucc_dt_size (dt );
43
+ ucc_status_t status = UCC_OK ;
43
44
ucc_rank_t sendto , recvfrom , sblock , rblock ;
44
45
int step ;
45
46
void * buf ;
@@ -69,7 +70,14 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
69
70
}
70
71
}
71
72
ucc_assert (UCC_TL_UCP_TASK_P2P_COMPLETE (task ));
72
- task -> super .status = UCC_OK ;
73
+ if (task -> allgather_ring .etask ) {
74
+ status = ucc_ee_executor_task_test (task -> allgather_ring .etask );
75
+ if (status == UCC_INPROGRESS ) {
76
+ return ;
77
+ }
78
+ ucc_ee_executor_task_finalize (task -> allgather_ring .etask );
79
+ }
80
+ task -> super .status = status ;
73
81
out :
74
82
UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgather_ring_done" , 0 );
75
83
}
@@ -88,22 +96,50 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
88
96
ucc_rank_t tsize = (ucc_rank_t )task -> subset .map .ep_num ;
89
97
size_t data_size = (count / tsize ) * ucc_dt_size (dt );
90
98
ucc_status_t status ;
91
- ucc_rank_t block ;
99
+ ucc_rank_t sendto , recvfrom , sblock , rblock ;
100
+ ucc_ee_executor_t * exec ;
101
+ ucc_ee_executor_task_args_t eargs ;
102
+ void * buf ;
92
103
93
104
UCC_TL_UCP_PROFILE_REQUEST_EVENT (coll_task , "ucp_allgather_ring_start" , 0 );
94
105
ucc_tl_ucp_task_reset (task , UCC_INPROGRESS );
95
106
107
+ sendto = ucc_ep_map_eval (task -> subset .map , (trank + 1 ) % tsize );
108
+ recvfrom = ucc_ep_map_eval (task -> subset .map , (trank - 1 + tsize ) % tsize );
109
+ sblock = task -> allgather_ring .get_send_block (& task -> subset , trank , tsize , 0 );
110
+ rblock = task -> allgather_ring .get_recv_block (& task -> subset , trank , tsize , 0 );
96
111
if (!UCC_IS_INPLACE (TASK_ARGS (task ))) {
97
- block = task -> allgather_ring .get_send_block (& task -> subset , trank , tsize ,
98
- 0 );
99
- status = ucc_mc_memcpy (PTR_OFFSET (rbuf , data_size * block ),
100
- sbuf , data_size , rmem , smem );
101
- if (ucc_unlikely (UCC_OK != status )) {
112
+ status = ucc_coll_task_get_executor (& task -> super , & exec );
113
+ if (ucc_unlikely (status != UCC_OK )) {
102
114
return status ;
103
115
}
116
+
117
+ eargs .task_type = UCC_EE_EXECUTOR_TASK_COPY ;
118
+ eargs .copy .src = sbuf ;
119
+ eargs .copy .dst = PTR_OFFSET (rbuf , data_size * sblock );
120
+ eargs .copy .len = data_size ;
121
+
122
+ status = ucc_ee_executor_task_post (exec , & eargs ,
123
+ & task -> allgather_ring .etask );
124
+ if (ucc_unlikely (status != UCC_OK )) {
125
+ return status ;
126
+ }
127
+ buf = sbuf ;
128
+ } else {
129
+ task -> allgather_ring .etask = NULL ;
130
+ buf = PTR_OFFSET (rbuf , data_size * sblock );
104
131
}
105
132
133
+ UCPCHECK_GOTO (ucc_tl_ucp_send_nb (buf , data_size , smem , sendto , team , task ),
134
+ task , out );
135
+ UCPCHECK_GOTO (ucc_tl_ucp_recv_nb (PTR_OFFSET (rbuf , rblock * data_size ),
136
+ data_size , rmem , recvfrom , team , task ),
137
+ task , out );
138
+
106
139
return ucc_progress_queue_enqueue (UCC_TL_CORE_CTX (team )-> pq , & task -> super );
140
+
141
+ out :
142
+ return status ;
107
143
}
108
144
109
145
ucc_status_t ucc_tl_ucp_allgather_ring_init_common (ucc_tl_ucp_task_t * task )
@@ -128,6 +164,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
128
164
task -> allgather_ring .get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block ;
129
165
task -> super .post = ucc_tl_ucp_allgather_ring_start ;
130
166
task -> super .progress = ucc_tl_ucp_allgather_ring_progress ;
167
+ if (!UCC_IS_INPLACE (TASK_ARGS (task ))) {
168
+ task -> super .flags |= UCC_COLL_TASK_FLAG_EXECUTOR ;
169
+ }
131
170
132
171
return UCC_OK ;
133
172
}
0 commit comments