Commit cc33fd5

TL/UCP: Add all-reduce ring algorithm
Signed-off-by: Armen Ratner <[email protected]>
1 parent fe0773f commit cc33fd5

5 files changed: +354 -1 lines changed


src/components/tl/ucp/Makefile.am

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,8 @@ allreduce = \
 	allreduce/allreduce_sliding_window.h \
 	allreduce/allreduce_sliding_window.c \
 	allreduce/allreduce_sliding_window_setup.c \
-	allreduce/allreduce_dbt.c
+	allreduce/allreduce_dbt.c \
+	allreduce/allreduce_ring.c

 barrier = \
 	barrier/barrier.h \

src/components/tl/ucp/allreduce/allreduce.c

Lines changed: 4 additions & 0 deletions

@@ -29,6 +29,10 @@ ucc_base_coll_alg_info_t
         {.id   = UCC_TL_UCP_ALLREDUCE_ALG_SLIDING_WINDOW,
          .name = "sliding_window",
          .desc = "sliding window allreduce (optimized for running on DPU)"},
+    [UCC_TL_UCP_ALLREDUCE_ALG_RING] =
+        {.id   = UCC_TL_UCP_ALLREDUCE_ALG_RING,
+         .name = "ring",
+         .desc = "ring-based allreduce (optimized for BW and simple topologies)"},
     [UCC_TL_UCP_ALLREDUCE_ALG_LAST] = {
         .id = 0, .name = NULL, .desc = NULL}};

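Note on the "optimized for BW" wording in the new table entry: it refers to the usual cost model for ring collectives, in which traffic is spread evenly over the links and per-rank volume stays roughly constant as the team grows. A self-contained sketch of that arithmetic for the textbook ring allreduce follows (illustrative only; the 1 MiB payload is a made-up example, not something taken from this commit):

#include <stdio.h>

/* Textbook ring allreduce cost model: 2 * (p - 1) steps, each moving
 * about N / p bytes per rank, so per-rank traffic approaches 2 * N. */
int main(void)
{
    const double N = 1048576.0; /* hypothetical 1 MiB payload per rank */
    int p;

    for (p = 2; p <= 64; p *= 2) {
        double bytes_per_rank = 2.0 * (p - 1) * (N / p);
        printf("ranks=%2d  bytes per rank ~ %10.0f  (%.2f x N)\n",
               p, bytes_per_rank, bytes_per_rank / N);
    }
    return 0;
}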

src/components/tl/ucp/allreduce/allreduce.h

Lines changed: 11 additions & 0 deletions

@@ -13,6 +13,7 @@ enum {
     UCC_TL_UCP_ALLREDUCE_ALG_SRA_KNOMIAL,
     UCC_TL_UCP_ALLREDUCE_ALG_SLIDING_WINDOW,
     UCC_TL_UCP_ALLREDUCE_ALG_DBT,
+    UCC_TL_UCP_ALLREDUCE_ALG_RING,
     UCC_TL_UCP_ALLREDUCE_ALG_LAST
 };

@@ -77,6 +78,16 @@ ucc_status_t ucc_tl_ucp_allreduce_dbt_start(ucc_coll_task_t *task);

 ucc_status_t ucc_tl_ucp_allreduce_dbt_progress(ucc_coll_task_t *task);

+void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task);
+
+ucc_status_t ucc_tl_ucp_allreduce_ring_start(ucc_coll_task_t *coll_task);
+
+ucc_status_t ucc_tl_ucp_allreduce_ring_init(ucc_base_coll_args_t *coll_args,
+                                            ucc_base_team_t      *team,
+                                            ucc_coll_task_t     **task_h);
+
+ucc_status_t ucc_tl_ucp_allreduce_ring_finalize(ucc_coll_task_t *coll_task);
+
 static inline int ucc_tl_ucp_allreduce_alg_from_str(const char *str)
 {
     int i;

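The truncated ucc_tl_ucp_allreduce_alg_from_str() above is what lets a user select the new algorithm by name: it resolves a string against the per-algorithm info table extended in allreduce.c. A self-contained model of that name-to-id lookup (the enum values, names, and helper below are illustrative stand-ins, not the UCC definitions):

#include <stdio.h>
#include <strings.h>

/* Illustrative stand-in for the TL/UCP allreduce algorithm table; the
 * real table lives in allreduce.c and the enum in allreduce.h. */
enum { ALG_KNOMIAL, ALG_SRA_KNOMIAL, ALG_SLIDING_WINDOW, ALG_DBT, ALG_RING,
       ALG_LAST };

static const char *alg_names[ALG_LAST] = {
    "knomial", "sra_knomial", "sliding_window", "dbt", "ring"
};

/* Walk the table; ALG_LAST means the name is unknown. */
static int alg_from_str(const char *str)
{
    int i;

    for (i = 0; i < ALG_LAST; i++) {
        if (0 == strcasecmp(str, alg_names[i])) {
            break;
        }
    }
    return i;
}

int main(void)
{
    printf("\"ring\"  -> %d (ALG_RING)\n", alg_from_str("ring"));
    printf("\"bogus\" -> %d (ALG_LAST)\n", alg_from_str("bogus"));
    return 0;
}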

src/components/tl/ucp/allreduce/allreduce_ring.c

Lines changed: 327 additions & 0 deletions

@@ -0,0 +1,327 @@
#include "config.h"
#include "tl_ucp.h"
#include "allreduce.h"
#include "core/ucc_progress_queue.h"
#include "tl_ucp_sendrecv.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"
#include "utils/ucc_dt_reduce.h"
#include "components/ec/ucc_ec.h"

void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
    ucc_rank_t         trank     = task->subset.myrank;
    ucc_rank_t         tsize     = (ucc_rank_t)task->subset.map.ep_num;
    void              *sbuf      = TASK_ARGS(task).src.info.buffer;
    void              *rbuf      = TASK_ARGS(task).dst.info.buffer;
    ucc_memory_type_t  mem_type  = TASK_ARGS(task).dst.info.mem_type;
    size_t             count     = TASK_ARGS(task).dst.info.count;
    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
    size_t             dt_size   = ucc_dt_size(dt);
    size_t             data_size = count * dt_size;

    /* Early return for zero-count or single-rank edge cases */
    if (data_size == 0 || tsize <= 1) {
        /* If not in-place, we still need to copy sbuf into rbuf;
         * use the memory-type-aware copy so non-host buffers work */
        if (!UCC_IS_INPLACE(TASK_ARGS(task)) && data_size > 0) {
            ucc_mc_memcpy(rbuf, sbuf, data_size, mem_type, mem_type);
        }
        task->super.status = UCC_OK;
        UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_done",
                                         0);
        return;
    }

    size_t       chunk_size, send_offset, recv_offset;
    size_t       send_chunk_size, recv_chunk_size;
    ucc_rank_t   sendto, recvfrom;
    void        *recv_buf, *send_buf, *reduce_buf;
    ucc_status_t status;
    int          step;
    int          num_chunks = tsize;
    int          send_chunk, recv_chunk;
    enum {
        RING_PHASE_INIT,      /* Initialize step */
        RING_PHASE_SEND_RECV, /* Send/receive phase */
        RING_PHASE_REDUCE,    /* Reduction phase */
        RING_PHASE_COMPLETE   /* Step is complete, advance to next step */
    } phase;

    /* Divide data into chunks, ensuring chunk size is aligned to the
     * datatype size */
    chunk_size = ucc_div_round_up(data_size, num_chunks);
    chunk_size = ((chunk_size + dt_size - 1) / dt_size) * dt_size;

    if (UCC_IS_INPLACE(TASK_ARGS(task))) {
        sbuf = rbuf;
    }

    sendto   = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
    recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);

    /* Ring exchange schedule:
     * - Each rank starts with its full local input in rbuf
     * - In step s, a rank sends chunk (trank - s) mod tsize to its right
     *   neighbor and reduces the chunk received from its left neighbor
     *   into rbuf
     * - After tsize - 1 steps, chunk (trank + 1) mod tsize on each rank
     *   holds the contributions of all ranks
     */

    /* On first entry, initialize step and phase */
    if (task->allreduce_ring.step == 0 &&
        task->allreduce_ring.phase == RING_PHASE_INIT) {
        if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
            ucc_mc_memcpy(rbuf, sbuf, data_size, mem_type, mem_type);
        }
        task->allreduce_ring.phase = RING_PHASE_SEND_RECV;
    }

    /* Process steps 0 .. tsize-2 (the ring exchange takes tsize-1 steps) */
    while (task->allreduce_ring.step < tsize - 1) {
        step  = task->allreduce_ring.step;
        phase = task->allreduce_ring.phase;

        /* Check if we have a pending reduction task and test for completion */
        if (phase == RING_PHASE_REDUCE && task->allreduce_ring.etask != NULL) {
            status = ucc_ee_executor_task_test(task->allreduce_ring.etask);

            if (status == UCC_INPROGRESS) {
                /* Return and try again later */
                return;
            }

            if (ucc_unlikely(status != UCC_OK)) {
                tl_error(UCC_TASK_LIB(task), "reduction task failed: %s",
                         ucc_status_string(status));
                task->super.status = status;
                return;
            }

            ucc_ee_executor_task_finalize(task->allreduce_ring.etask);
            task->allreduce_ring.etask = NULL;
            task->allreduce_ring.phase = RING_PHASE_COMPLETE;
            /* Keep the local view in sync so the finished reduction is not
             * launched a second time further down in this pass */
            phase = RING_PHASE_COMPLETE;
        }

        /* If we've completed the current step, advance to the next one */
        if (phase == RING_PHASE_COMPLETE) {
            task->allreduce_ring.step++;
            task->allreduce_ring.phase = RING_PHASE_SEND_RECV;

            if (task->allreduce_ring.step >= tsize - 1) {
                break;
            }

            step = task->allreduce_ring.step;
        }

        /* Send/receive phase */
        if (phase == RING_PHASE_SEND_RECV) {
            send_chunk = (trank - step + tsize) % tsize;
            recv_chunk = (trank - step - 1 + tsize) % tsize;

            /* Calculate send offset and chunk size */
            send_offset = send_chunk * chunk_size;
            if (send_offset >= data_size) {
                task->allreduce_ring.phase = RING_PHASE_COMPLETE;
                continue;
            }

            /* Calculate actual size of this chunk */
            send_chunk_size = data_size - send_offset;
            if (send_chunk_size > chunk_size) {
                send_chunk_size = chunk_size;
            }
            send_chunk_size = (send_chunk_size / dt_size) * dt_size;

            if (send_chunk_size == 0) {
                task->allreduce_ring.phase = RING_PHASE_COMPLETE;
                continue;
            }

            /* Calculate receive offset and chunk size */
            recv_offset = recv_chunk * chunk_size;
            if (recv_offset >= data_size) {
                task->allreduce_ring.phase = RING_PHASE_COMPLETE;
                continue;
            }

            recv_chunk_size = data_size - recv_offset;
            if (recv_chunk_size > chunk_size) {
                recv_chunk_size = chunk_size;
            }
            recv_chunk_size = (recv_chunk_size / dt_size) * dt_size;

            if (recv_chunk_size == 0) {
                task->allreduce_ring.phase = RING_PHASE_COMPLETE;
                continue;
            }

            /* Post the send and receive for this step */
            send_buf = PTR_OFFSET(rbuf, send_offset);
            recv_buf = task->allreduce_ring.scratch;

            UCPCHECK_GOTO(ucc_tl_ucp_send_nb(send_buf, send_chunk_size, mem_type,
                                             sendto, team, task),
                          task, out);

            UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(recv_buf, recv_chunk_size, mem_type,
                                             recvfrom, team, task),
                          task, out);

            /* Save chunk information for the reduction phase before testing,
             * so that a re-entry after UCC_INPROGRESS resumes in
             * RING_PHASE_REDUCE instead of reposting the send/recv */
            task->allreduce_ring.phase       = RING_PHASE_REDUCE;
            task->allreduce_ring.recv_offset = recv_offset;
            task->allreduce_ring.recv_size   = recv_chunk_size;

            if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
                return;
            }
        }

        if (phase == RING_PHASE_REDUCE) {
            /* Make sure the send/recv for this step have completed before
             * launching the reduction (re-entry resumes here) */
            if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
                return;
            }

            recv_offset     = task->allreduce_ring.recv_offset;
            recv_chunk_size = task->allreduce_ring.recv_size;

            recv_buf   = task->allreduce_ring.scratch;
            reduce_buf = PTR_OFFSET(rbuf, recv_offset);

            status = ucc_dt_reduce(reduce_buf, recv_buf, reduce_buf,
                                   recv_chunk_size / dt_size,
                                   dt, &TASK_ARGS(task), 0, 0,
                                   task->allreduce_ring.executor,
                                   &task->allreduce_ring.etask);

            if (ucc_unlikely(status != UCC_OK)) {
                tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
                task->super.status = status;
                return;
            }

            if (task->allreduce_ring.etask != NULL) {
                /* Reduction is running asynchronously; test it on re-entry */
                return;
            }

            task->allreduce_ring.phase = RING_PHASE_COMPLETE;
        }
    }

    ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
    task->super.status = UCC_OK;
out:
    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_done", 0);
}

ucc_status_t ucc_tl_ucp_allreduce_ring_start(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team = TASK_TEAM(task);

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_start", 0);
    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_ucp_allreduce_ring_init_common(ucc_tl_ucp_task_t *task)
{
    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
    ucc_sbgp_t        *sbgp;
    size_t             count     = TASK_ARGS(task).dst.info.count;
    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
    size_t             dt_size   = ucc_dt_size(dt);
    size_t             data_size = count * dt_size;
    size_t             chunk_size;
    ucc_status_t       status;

    if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
        tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
        return UCC_ERR_NOT_SUPPORTED;
    }

    if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET) && team->cfg.use_reordering) {
        sbgp                = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
        task->subset.myrank = sbgp->group_rank;
        task->subset.map    = sbgp->map;
    }

    /* Calculate chunk size for a single chunk */
    chunk_size = ucc_div_round_up(data_size, task->subset.map.ep_num);
    chunk_size = ((chunk_size + dt_size - 1) / dt_size) * dt_size;

    /* Allocate scratch space for a single chunk */
    status = ucc_mc_alloc(&task->allreduce_ring.scratch_mc_header,
                          chunk_size, TASK_ARGS(task).dst.info.mem_type);
    if (ucc_unlikely(status != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
        return status;
    }
    task->allreduce_ring.scratch = task->allreduce_ring.scratch_mc_header->addr;

    task->allreduce_ring.step  = 0;
    task->allreduce_ring.phase = 0;
    task->allreduce_ring.etask = NULL;

    ucc_ee_executor_params_t eparams = {0};
    eparams.mask    = UCC_EE_EXECUTOR_PARAM_FIELD_TYPE;
    eparams.ee_type = UCC_EE_CPU_THREAD;
    status = ucc_ee_executor_init(&eparams, &task->allreduce_ring.executor);
    if (ucc_unlikely(status != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to initialize executor");
        ucc_mc_free(task->allreduce_ring.scratch_mc_header);
        return status;
    }

    task->super.post     = ucc_tl_ucp_allreduce_ring_start;
    task->super.progress = ucc_tl_ucp_allreduce_ring_progress;
    task->super.finalize = ucc_tl_ucp_allreduce_ring_finalize;

    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_allreduce_ring_init(ucc_base_coll_args_t *coll_args,
                                            ucc_base_team_t      *team,
                                            ucc_coll_task_t     **task_h)
{
    ucc_tl_ucp_task_t *task;
    ucc_status_t       status;

    task   = ucc_tl_ucp_init_task(coll_args, team);
    status = ucc_tl_ucp_allreduce_ring_init_common(task);
    if (status != UCC_OK) {
        ucc_tl_ucp_put_task(task);
        return status;
    }
    *task_h = &task->super;
    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_allreduce_ring_finalize(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_status_t       st, global_st = UCC_OK;

    if (task->allreduce_ring.etask != NULL) {
        ucc_ee_executor_task_finalize(task->allreduce_ring.etask);
        task->allreduce_ring.etask = NULL;
    }

    if (task->allreduce_ring.executor != NULL) {
        st = ucc_ee_executor_finalize(task->allreduce_ring.executor);
        if (ucc_unlikely(st != UCC_OK)) {
            tl_error(UCC_TASK_LIB(task), "failed to finalize executor");
            global_st = st;
        }
        task->allreduce_ring.executor = NULL;
    }

    st = ucc_mc_free(task->allreduce_ring.scratch_mc_header);
    if (ucc_unlikely(st != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to free scratch buffer");
        global_st = (global_st == UCC_OK) ? st : global_st;
    }

    st = ucc_tl_ucp_coll_finalize(&task->super);
    if (ucc_unlikely(st != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to finalize collective");
        global_st = (global_st == UCC_OK) ? st : global_st;
    }
    return global_st;
}

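To make the chunk bookkeeping in the progress function easier to follow, here is a self-contained, single-process trace of the same ring exchange: in step s, rank r sends chunk (r - s) mod p to rank (r + 1) mod p, receives chunk (r - s - 1) mod p from rank (r - 1) mod p, and accumulates it locally. The model is illustrative only (no UCP, no chunk-size rounding); it prints the schedule and which chunk ends up fully reduced on each rank. In the classic formulation this reduce-scatter rotation is then followed by an allgather rotation that circulates the finished chunks.

#include <stdio.h>

#define P 4 /* hypothetical team size */

/* contrib[r][c] counts how many ranks' contributions have been
 * accumulated into chunk c of rank r's result buffer. */
int main(void)
{
    int contrib[P][P];
    int r, c, s;

    for (r = 0; r < P; r++) {
        for (c = 0; c < P; c++) {
            contrib[r][c] = 1; /* each rank starts with only its own data */
        }
    }

    for (s = 0; s < P - 1; s++) {
        int sent[P]; /* snapshot of what each rank sends this step */

        for (r = 0; r < P; r++) {
            sent[r] = contrib[r][(r - s + P) % P];
        }
        for (r = 0; r < P; r++) {
            int recv_chunk = (r - s - 1 + P) % P;
            int left       = (r - 1 + P) % P;

            printf("step %d: rank %d sends chunk %d, reduces chunk %d from rank %d\n",
                   s, r, (r - s + P) % P, recv_chunk, left);
            contrib[r][recv_chunk] += sent[left];
        }
    }

    for (r = 0; r < P; r++) {
        printf("rank %d: chunk %d is fully reduced (%d contributions)\n",
               r, (r + 1) % P, contrib[r][(r + 1) % P]);
    }
    return 0;
}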

src/components/tl/ucp/tl_ucp_coll.h

Lines changed: 10 additions & 0 deletions

@@ -268,6 +268,16 @@ typedef struct ucc_tl_ucp_task {
             ucc_rank_t              iteration;
             int                     phase;
         } alltoall_bruck;
+        struct {
+            void                   *scratch;
+            ucc_mc_buffer_header_t *scratch_mc_header;
+            ucc_ee_executor_task_t *etask;
+            ucc_ee_executor_t      *executor;
+            int                     step;
+            int                     phase;
+            size_t                  recv_offset;
+            size_t                  recv_size;
+        } allreduce_ring;
         char plugin_data[UCC_TL_UCP_TASK_PLUGIN_MAX_DATA];
     };
 } ucc_tl_ucp_task_t;

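The new allreduce_ring block keeps everything the progress routine needs in order to resume after returning early, whether it is waiting on outstanding sends/receives or on an executor reduction: the current step, the phase within that step, and the offset/size of the chunk still to be reduced. A minimal standalone model of that resume-from-saved-state pattern (the names and the fake completion flag are illustrative, not UCC APIs):

#include <stdio.h>

/* Illustrative per-task state, mirroring the role of allreduce_ring.step
 * and allreduce_ring.phase: progress() may return at any point and be
 * re-entered later without losing its place. */
typedef struct {
    int step;  /* current ring step                          */
    int phase; /* 0 = need to post, 1 = waiting for the wire */
} toy_task_t;

static int wire_ready; /* stand-in for "did the transfer complete yet?" */

/* Returns 1 when all steps are done, 0 when it must be called again. */
static int toy_progress(toy_task_t *t, int nsteps)
{
    while (t->step < nsteps) {
        if (t->phase == 0) {
            printf("step %d: posting send/recv\n", t->step);
            wire_ready = 0; /* transfer just started          */
            t->phase   = 1; /* remember where we are and wait */
        }
        if (!wire_ready) {
            return 0;       /* come back later, state is saved */
        }
        printf("step %d: transfer done, reducing and advancing\n", t->step);
        t->step++;
        t->phase = 0;
    }
    return 1;
}

int main(void)
{
    toy_task_t t     = {0, 0};
    int        calls = 0;

    while (!toy_progress(&t, 3)) {
        calls++;
        wire_ready = 1; /* pretend the network finished between calls */
    }
    printf("completed after %d re-entries\n", calls);
    return 0;
}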
