#include "config.h"
#include "tl_ucp.h"
#include "allreduce.h"
#include "core/ucc_progress_queue.h"
#include "tl_ucp_sendrecv.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"
#include "utils/ucc_dt_reduce.h"
#include "components/ec/ucc_ec.h"

void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
{
    /* The progress function is re-entered until the collective completes, so
     * the current step and phase live in task->allreduce_ring */
    enum {
        RING_PHASE_INIT,      /* First entry: seed the destination buffer */
        RING_PHASE_SEND_RECV, /* Post the send/receive for the current step */
        RING_PHASE_REDUCE,    /* Reduce the received chunk into rbuf */
        RING_PHASE_COMPLETE   /* Step is complete, advance to the next step */
    } phase;
    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
    ucc_rank_t         trank     = task->subset.myrank;
    ucc_rank_t         tsize     = (ucc_rank_t)task->subset.map.ep_num;
    void              *sbuf      = TASK_ARGS(task).src.info.buffer;
    void              *rbuf      = TASK_ARGS(task).dst.info.buffer;
    ucc_memory_type_t  mem_type  = TASK_ARGS(task).dst.info.mem_type;
    size_t             count     = TASK_ARGS(task).dst.info.count;
    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
    size_t             dt_size   = ucc_dt_size(dt);
    size_t             data_size = count * dt_size;
    size_t             chunk_size, send_offset, recv_offset;
    size_t             send_chunk_size, recv_chunk_size;
    ucc_rank_t         sendto, recvfrom;
    void              *recv_buf, *send_buf, *reduce_buf;
    ucc_status_t       status;
    int                step, send_chunk, recv_chunk;

    /* Early return for zero-count or single-rank edge cases */
    if (data_size == 0 || tsize <= 1) {
        /* If not in-place, the source still has to be copied to the destination */
        if (!UCC_IS_INPLACE(TASK_ARGS(task)) && data_size > 0) {
            status = ucc_mc_memcpy(rbuf, sbuf, data_size, mem_type,
                                   TASK_ARGS(task).src.info.mem_type);
            if (ucc_unlikely(status != UCC_OK)) {
                task->super.status = status;
                return;
            }
        }
        task->super.status = UCC_OK;
        UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_done", 0);
        return;
    }

    /* Divide the buffer into tsize chunks, each aligned to the datatype size */
    chunk_size = ucc_div_round_up(data_size, tsize);
    chunk_size = ((chunk_size + dt_size - 1) / dt_size) * dt_size;
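
    /* Example (illustrative): count = 10, dt_size = 4, tsize = 4 gives
     * data_size = 40 and chunk_size = 12, i.e. chunks of 12, 12, 12 and 4
     * bytes; trailing chunks may be shorter than chunk_size or even empty. */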

    if (UCC_IS_INPLACE(TASK_ARGS(task))) {
        sbuf = rbuf;
    }

    sendto   = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
    recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);

    /* Ring allreduce:
     * - The buffer is split into tsize chunks.
     * - Reduce-scatter stage (steps 0 .. tsize-2): each chunk travels around
     *   the ring and is reduced at every hop, so after tsize-1 steps every
     *   rank owns one fully reduced chunk.
     * - Allgather stage (steps tsize-1 .. 2*(tsize-1)-1): the fully reduced
     *   chunks are circulated around the ring so that every rank ends up with
     *   the complete result.
     */
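    /* Illustration for tsize = 3 (all chunk indices mod 3): in reduce-scatter
     * step 0 rank r sends chunk r and reduces incoming chunk r-1; in step 1 it
     * sends chunk r-1 and reduces chunk r+1, which is then fully reduced.
     * Allgather step 0 sends that chunk (r+1) and receives chunk r; allgather
     * step 1 sends chunk r and receives chunk r-1, completing the result. */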

    /* On first entry, seed the destination buffer with the local contribution */
    if (task->allreduce_ring.step == 0 &&
        task->allreduce_ring.phase == RING_PHASE_INIT) {
        if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
            status = ucc_mc_memcpy(rbuf, sbuf, data_size, mem_type,
                                   TASK_ARGS(task).src.info.mem_type);
            if (ucc_unlikely(status != UCC_OK)) {
                task->super.status = status;
                return;
            }
        }
        task->allreduce_ring.phase = RING_PHASE_SEND_RECV;
    }

    /* 2*(tsize-1) steps in total: tsize-1 reduce-scatter steps followed by
     * tsize-1 allgather steps */
    while (task->allreduce_ring.step < 2 * (tsize - 1)) {
        step  = task->allreduce_ring.step;
        phase = task->allreduce_ring.phase;
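
        /* Each iteration advances the state machine for the current step; any
         * point that has to wait (p2p completion or an executor task) returns,
         * and the next progress call resumes from the phase saved in the task */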

        /* A reduction was posted to the executor: test it for completion */
        if (phase == RING_PHASE_REDUCE && task->allreduce_ring.etask != NULL) {
            status = ucc_ee_executor_task_test(task->allreduce_ring.etask);
            if (status == UCC_INPROGRESS) {
                /* Not done yet, try again on the next progress call */
                return;
            }
            if (ucc_unlikely(status != UCC_OK)) {
                tl_error(UCC_TASK_LIB(task), "reduction task failed: %s",
                         ucc_status_string(status));
                task->super.status = status;
                return;
            }
            ucc_ee_executor_task_finalize(task->allreduce_ring.etask);
            task->allreduce_ring.etask = NULL;
            /* Also update the local phase, otherwise the reduce block below
             * would post the same reduction a second time */
            task->allreduce_ring.phase = RING_PHASE_COMPLETE;
            phase                      = RING_PHASE_COMPLETE;
        }

        /* The step is finished: make sure its send/recv have completed, then
         * advance to the next step */
        if (phase == RING_PHASE_COMPLETE) {
            if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
                return;
            }
            task->allreduce_ring.step++;
            task->allreduce_ring.phase = RING_PHASE_SEND_RECV;
            phase                      = RING_PHASE_SEND_RECV;

            if (task->allreduce_ring.step >= 2 * (tsize - 1)) {
                break;
            }
            step = task->allreduce_ring.step;
        }

        /* Post the send and the receive for the current step */
        if (phase == RING_PHASE_SEND_RECV) {
            if (step < tsize - 1) {
                /* Reduce-scatter stage: pass partially reduced chunks
                 * backwards around the ring */
                send_chunk = (trank - step + tsize) % tsize;
                recv_chunk = (trank - step - 1 + tsize) % tsize;
            } else {
                /* Allgather stage: circulate the fully reduced chunks */
                int ag_step = step - (tsize - 1);

                send_chunk = (trank + 1 - ag_step + tsize) % tsize;
                recv_chunk = (trank - ag_step + tsize) % tsize;
            }

            /* chunk_size and data_size are both multiples of dt_size, so the
             * sizes below stay datatype-aligned; trailing chunks may be
             * shorter than chunk_size or empty */
            send_offset     = send_chunk * chunk_size;
            send_chunk_size = (send_offset < data_size)
                                  ? ucc_min(chunk_size, data_size - send_offset)
                                  : 0;
            recv_offset     = recv_chunk * chunk_size;
            recv_chunk_size = (recv_offset < data_size)
                                  ? ucc_min(chunk_size, data_size - recv_offset)
                                  : 0;

            if (send_chunk_size == 0 && recv_chunk_size == 0) {
                task->allreduce_ring.phase = RING_PHASE_COMPLETE;
                phase                      = RING_PHASE_COMPLETE;
                continue;
            }

            /* Post the send and the receive independently: every rank derives
             * the chunk sizes from the same layout, so a peer posts the
             * matching operation exactly when the chunk is non-empty */
            if (send_chunk_size > 0) {
                send_buf = PTR_OFFSET(rbuf, send_offset);
                UCPCHECK_GOTO(ucc_tl_ucp_send_nb(send_buf, send_chunk_size,
                                                 mem_type, sendto, team, task),
                              task, out);
            }
            if (recv_chunk_size > 0) {
                /* Reduce-scatter receives into scratch and reduces afterwards;
                 * allgather receives the final chunk directly into place */
                recv_buf = (step < tsize - 1)
                               ? task->allreduce_ring.scratch
                               : PTR_OFFSET(rbuf, recv_offset);
                UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(recv_buf, recv_chunk_size,
                                                 mem_type, recvfrom, team, task),
                              task, out);
            }

            /* Advance the phase before testing, so that a later re-entry does
             * not post the same send/recv again */
            if (step < tsize - 1 && recv_chunk_size > 0) {
                task->allreduce_ring.recv_offset = recv_offset;
                task->allreduce_ring.recv_size   = recv_chunk_size;
                task->allreduce_ring.phase       = RING_PHASE_REDUCE;
                phase                            = RING_PHASE_REDUCE;
            } else {
                task->allreduce_ring.phase = RING_PHASE_COMPLETE;
                phase                      = RING_PHASE_COMPLETE;
            }

            if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
                return;
            }
        }

        /* Reduce the chunk received in this step into the destination buffer */
        if (phase == RING_PHASE_REDUCE) {
            /* The reduction can only start once the receive has completed */
            if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
                return;
            }

            recv_offset     = task->allreduce_ring.recv_offset;
            recv_chunk_size = task->allreduce_ring.recv_size;
            recv_buf        = task->allreduce_ring.scratch;
            reduce_buf      = PTR_OFFSET(rbuf, recv_offset);

            status = ucc_dt_reduce(reduce_buf, recv_buf, reduce_buf,
                                   recv_chunk_size / dt_size, dt,
                                   &TASK_ARGS(task), 0, 0,
                                   task->allreduce_ring.executor,
                                   &task->allreduce_ring.etask);
            if (ucc_unlikely(status != UCC_OK)) {
                tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
                task->super.status = status;
                return;
            }

            if (task->allreduce_ring.etask != NULL) {
                /* The reduction runs asynchronously; it is tested at the top
                 * of the loop on the next progress call */
                return;
            }
            task->allreduce_ring.phase = RING_PHASE_COMPLETE;
            phase                      = RING_PHASE_COMPLETE;
        }
    }

    ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
    task->super.status = UCC_OK;
out:
    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_done", 0);
}

ucc_status_t ucc_tl_ucp_allreduce_ring_start(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team = TASK_TEAM(task);

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_start", 0);
    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
    /* Reset the ring state machine so the same task can be posted again */
    task->allreduce_ring.step  = 0;
    task->allreduce_ring.phase = 0; /* RING_PHASE_INIT */
    task->allreduce_ring.etask = NULL;

    /* The task is then driven by repeated calls to the progress callback from
     * the core progress queue until it sets super.status */
    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_ucp_allreduce_ring_init_common(ucc_tl_ucp_task_t *task)
{
    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
    ucc_sbgp_t        *sbgp;
    size_t             count     = TASK_ARGS(task).dst.info.count;
    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
    size_t             dt_size   = ucc_dt_size(dt);
    size_t             data_size = count * dt_size;
    size_t             chunk_size;
    ucc_status_t       status;

    if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
        tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
        return UCC_ERR_NOT_SUPPORTED;
    }

    if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET) && team->cfg.use_reordering) {
        sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
        task->subset.myrank = sbgp->group_rank;
        task->subset.map    = sbgp->map;
    }

    /* Size of a single chunk, aligned to the datatype size (must match the
     * calculation in the progress function) */
    chunk_size = ucc_div_round_up(data_size, task->subset.map.ep_num);
    chunk_size = ((chunk_size + dt_size - 1) / dt_size) * dt_size;
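    /* A single chunk of scratch is enough: at most one chunk is in flight per
     * step, and the allgather stage receives directly into the destination */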

    /* Allocate scratch space for a single chunk */
    status = ucc_mc_alloc(&task->allreduce_ring.scratch_mc_header, chunk_size,
                          TASK_ARGS(task).dst.info.mem_type);
    if (ucc_unlikely(status != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
        return status;
    }
    task->allreduce_ring.scratch = task->allreduce_ring.scratch_mc_header->addr;
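    /* The scratch chunk uses the destination memory type so the receive and
     * the reduction can operate on it directly */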

    task->allreduce_ring.step  = 0;
    task->allreduce_ring.phase = 0; /* RING_PHASE_INIT */
    task->allreduce_ring.etask = NULL;
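
    /* Reductions are offloaded to an EE executor. A CPU-thread executor is
     * created here, which assumes host-resident reduction buffers; device
     * memory would need a matching executor type. */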
    ucc_ee_executor_params_t eparams = {0};
    eparams.mask    = UCC_EE_EXECUTOR_PARAM_FIELD_TYPE;
    eparams.ee_type = UCC_EE_CPU_THREAD;
    status = ucc_ee_executor_init(&eparams, &task->allreduce_ring.executor);
    if (ucc_unlikely(status != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to initialize executor");
        ucc_mc_free(task->allreduce_ring.scratch_mc_header);
        return status;
    }
    /* Start the executor before any reduction tasks are posted to it */
    status = ucc_ee_executor_start(task->allreduce_ring.executor, NULL);
    if (ucc_unlikely(status != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to start executor");
        ucc_ee_executor_finalize(task->allreduce_ring.executor);
        ucc_mc_free(task->allreduce_ring.scratch_mc_header);
        return status;
    }

    task->super.post     = ucc_tl_ucp_allreduce_ring_start;
    task->super.progress = ucc_tl_ucp_allreduce_ring_progress;
    task->super.finalize = ucc_tl_ucp_allreduce_ring_finalize;

    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_allreduce_ring_init(ucc_base_coll_args_t *coll_args,
                                            ucc_base_team_t      *team,
                                            ucc_coll_task_t     **task_h)
{
    ucc_tl_ucp_task_t *task;
    ucc_status_t       status;

    task   = ucc_tl_ucp_init_task(coll_args, team);
    status = ucc_tl_ucp_allreduce_ring_init_common(task);
    if (status != UCC_OK) {
        ucc_tl_ucp_put_task(task);
        return status;
    }
    *task_h = &task->super;
    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_allreduce_ring_finalize(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_status_t       st, global_st = UCC_OK;

    if (task->allreduce_ring.etask != NULL) {
        ucc_ee_executor_task_finalize(task->allreduce_ring.etask);
        task->allreduce_ring.etask = NULL;
    }

    if (task->allreduce_ring.executor != NULL) {
        /* Stop the executor before finalizing it */
        st = ucc_ee_executor_stop(task->allreduce_ring.executor);
        if (ucc_unlikely(st != UCC_OK)) {
            tl_error(UCC_TASK_LIB(task), "failed to stop executor");
            global_st = st;
        }
        st = ucc_ee_executor_finalize(task->allreduce_ring.executor);
        if (ucc_unlikely(st != UCC_OK)) {
            tl_error(UCC_TASK_LIB(task), "failed to finalize executor");
            global_st = (global_st == UCC_OK) ? st : global_st;
        }
        task->allreduce_ring.executor = NULL;
    }

    st = ucc_mc_free(task->allreduce_ring.scratch_mc_header);
    if (ucc_unlikely(st != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to free scratch buffer");
        global_st = (global_st == UCC_OK) ? st : global_st;
    }

    st = ucc_tl_ucp_coll_finalize(&task->super);
    if (ucc_unlikely(st != UCC_OK)) {
        tl_error(UCC_TASK_LIB(task), "failed to finalize collective");
        global_st = (global_st == UCC_OK) ? st : global_st;
    }
    return global_st;
}