|
11 | 11 | #include "utils/ucc_math.h" |
12 | 12 | #include "tl_ucp_sendrecv.h" |
13 | 13 |
|
| 14 | +#define CONGESTION_THRESHOLD 8 |
| 15 | +#define GET_OP 0 |
| 16 | +#define PUT_OP 1 |
| 17 | + |
14 | 18 | void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask); |
15 | 19 |
|
16 | | -ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask) |
| 20 | +ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_start(ucc_coll_task_t *ctask) |
17 | 21 | { |
18 | | - ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); |
19 | | - ucc_tl_ucp_team_t *team = TASK_TEAM(task); |
20 | | - ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer; |
21 | | - ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer; |
22 | | - size_t nelems = TASK_ARGS(task).src.info.count; |
23 | | - ucc_rank_t grank = UCC_TL_TEAM_RANK(team); |
24 | | - ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); |
25 | | - ucc_rank_t start = (grank + 1) % gsize; |
26 | | - long *pSync = TASK_ARGS(task).global_work_buffer; |
27 | | - ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh; |
28 | | - ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh; |
29 | | - ucc_rank_t peer; |
| 22 | + return ucc_schedule_start(ctask); |
| 23 | +} |
| 24 | + |
| 25 | +ucc_status_t ucc_tl_ucp_alltoall_onesided_sched_finalize(ucc_coll_task_t *ctask) |
| 26 | +{ |
| 27 | + ucc_schedule_t *schedule = ucc_derived_of(ctask, ucc_schedule_t); |
| 28 | + ucc_status_t status; |
| 29 | + |
| 30 | + status = ucc_schedule_finalize(ctask); |
| 31 | + ucc_tl_ucp_put_schedule(schedule); |
| 32 | + return status; |
| 33 | +} |
| 34 | + |
| 35 | +ucc_status_t ucc_tl_ucp_alltoall_onesided_finalize(ucc_coll_task_t *coll_task) |
| 36 | +{ |
| 37 | + ucc_status_t status; |
| 38 | + |
| 39 | + status = ucc_tl_ucp_coll_finalize(coll_task); |
| 40 | + if (ucc_unlikely(UCC_OK != status)) { |
| 41 | + tl_error(UCC_TASK_LIB(coll_task), "failed to finalize collective"); |
| 42 | + } |
| 43 | + return status; |
| 44 | +} |
30 | 45 |
|
31 | | - if (TASK_ARGS(task).flags & UCC_COLL_ARGS_FLAG_SRC_MEMH_GLOBAL) { |
32 | | - src_memh = TASK_ARGS(task).src_memh.global_memh[grank]; |
| 46 | +void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask) |
| 47 | +{ |
| 48 | + ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); |
| 49 | + ucc_tl_ucp_team_t *team = TASK_TEAM(task); |
| 50 | + ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer; |
| 51 | + ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer; |
| 52 | + ucc_rank_t grank = UCC_TL_TEAM_RANK(team); |
| 53 | + ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); |
| 54 | + ucc_rank_t start = (grank + 1) % gsize; |
| 55 | + ucc_rank_t peer = (start + task->alltoall_onesided.iteration) % gsize; |
| 56 | + int iteration = task->alltoall_onesided.iteration; |
| 57 | + size_t nreqs = task->alltoall_onesided.tokens; |
| 58 | + int64_t polls = 0; |
| 59 | + int64_t npolls = task->alltoall_onesided.npolls; |
| 60 | + int use_op = task->alltoall_onesided.op; |
| 61 | + ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh; |
| 62 | + ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh; |
| 63 | + uint32_t *posted; |
| 64 | + uint32_t *completed; |
| 65 | + size_t nelems; |
| 66 | + |
| 67 | + nelems = TASK_ARGS(task).src.info.count; |
| 68 | + nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); |
| 69 | + |
| 70 | + if (use_op == GET_OP) { |
| 71 | + /* these are not typos */ |
| 72 | + src_memh = TASK_ARGS(task).dst_memh.local_memh; |
| 73 | + dst_memh = TASK_ARGS(task).src_memh.global_memh; |
| 74 | + posted = &task->onesided.get_posted; |
| 75 | + completed = &task->onesided.get_completed; |
| 76 | + } else { |
| 77 | + posted = &task->onesided.put_posted; |
| 78 | + completed = &task->onesided.put_completed; |
33 | 79 | } |
| 80 | + for (; *posted < gsize; peer = (peer + 1) % gsize) { |
| 81 | + if (use_op == PUT_OP) { |
| 82 | + UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems), |
| 83 | + (void *)PTR_OFFSET(dest, grank * nelems), nelems, peer, src_memh, dst_memh, team, task), |
| 84 | + task, out); |
| 85 | + } else { |
| 86 | + UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems), |
| 87 | + (void *)PTR_OFFSET(src, grank * nelems), nelems, peer, src_memh, dst_memh, team, task), |
| 88 | + task, out); |
| 89 | + } |
| 90 | + ++iteration; |
| 91 | + |
| 92 | + if ((*posted - *completed) >= nreqs) { |
| 93 | + while (polls++ < npolls) { |
| 94 | + ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker); |
| 95 | + if ((*posted - *completed) < nreqs) { |
| 96 | + break; |
| 97 | + } |
| 98 | + } |
| 99 | + if (polls >= npolls) { |
| 100 | + task->alltoall_onesided.iteration = iteration; |
| 101 | + return; |
| 102 | + } |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + if (!UCC_TL_UCP_TASK_ONESIDED_P2P_COMPLETE(task)) { |
| 107 | + while (polls++ < npolls) { |
| 108 | + ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker); |
| 109 | + if (UCC_TL_UCP_TASK_ONESIDED_P2P_COMPLETE(task)) { |
| 110 | + goto complete; |
| 111 | + } |
| 112 | + } |
| 113 | + ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker); |
| 114 | + if (UCC_TL_UCP_TASK_ONESIDED_P2P_COMPLETE(task)) { |
| 115 | + goto complete; |
| 116 | + } |
| 117 | + return; |
| 118 | + } |
| 119 | + |
| 120 | +complete: |
| 121 | + task->super.status = UCC_OK; |
| 122 | +out: |
| 123 | + return; |
| 124 | +} |
| 125 | + |
| 126 | +ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask) |
| 127 | +{ |
| 128 | + ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); |
| 129 | + ucc_tl_ucp_team_t *team = TASK_TEAM(task); |
| 130 | + ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer; |
| 131 | + ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer; |
| 132 | + ucc_rank_t grank = UCC_TL_TEAM_RANK(team); |
| 133 | + ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); |
| 134 | + ucc_rank_t start = (grank + 1) % gsize; |
| 135 | + size_t nelems = TASK_ARGS(task).src.info.count; |
| 136 | + size_t nreqs = task->alltoall_onesided.tokens; |
| 137 | + size_t npolls = task->alltoall_onesided.npolls; |
| 138 | + int iteration = 0; |
| 139 | + int64_t polls = 0; |
| 140 | + int use_op = task->alltoall_onesided.op; |
| 141 | + ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh; |
| 142 | + ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh; |
| 143 | + uint32_t *posted; |
| 144 | + uint32_t *completed; |
| 145 | + ucc_rank_t peer; |
34 | 146 |
|
35 | 147 | ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); |
36 | | - /* TODO: change when support for library-based work buffers is complete */ |
| 148 | + if (use_op == GET_OP) { |
| 149 | + /* these are not typos */ |
| 150 | + src_memh = TASK_ARGS(task).dst_memh.local_memh; |
| 151 | + dst_memh = TASK_ARGS(task).src_memh.global_memh; |
| 152 | + posted = &task->onesided.get_posted; |
| 153 | + completed = &task->onesided.get_completed; |
| 154 | + } else { |
| 155 | + posted = &task->onesided.put_posted; |
| 156 | + completed = &task->onesided.put_completed; |
| 157 | + } |
37 | 158 | nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); |
38 | | - dest = dest + grank * nelems; |
39 | | - for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) { |
40 | | - UCPCHECK_GOTO(ucc_tl_ucp_put_nb( |
41 | | - (void *)(src + peer * nelems), (void *)dest, nelems, |
42 | | - peer, src_memh, dst_memh, team, task), |
| 159 | + for (peer = start; *posted < gsize; peer = (peer + 1) % gsize) { |
| 160 | + if (use_op == PUT_OP) { |
| 161 | + UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems), |
| 162 | + (void *)PTR_OFFSET(dest, grank * nelems), nelems, peer, src_memh, dst_memh, team, task), |
| 163 | + task, out); |
| 164 | + } else { |
| 165 | + UCPCHECK_GOTO(ucc_tl_ucp_get_nb(PTR_OFFSET(dest, peer * nelems), |
| 166 | + (void *)PTR_OFFSET(src, grank * nelems), nelems, peer, src_memh, dst_memh, team, task), |
43 | 167 | task, out); |
44 | | - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team), |
45 | | - task, out); |
| 168 | + } |
| 169 | + ++iteration; |
| 170 | + |
| 171 | + if ((*posted - *completed) >= nreqs) { |
| 172 | + while (polls++ < npolls) { |
| 173 | + ucp_worker_progress(TASK_CTX(task)->worker.ucp_worker); |
| 174 | + if ((*posted - *completed) < nreqs) { |
| 175 | + break; |
| 176 | + } |
| 177 | + } |
| 178 | + if (polls >= npolls) { |
| 179 | + break; |
| 180 | + } |
| 181 | + } |
46 | 182 | } |
| 183 | + task->alltoall_onesided.iteration = iteration; |
47 | 184 | return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); |
48 | 185 | out: |
49 | 186 | return task->super.status; |
50 | 187 | } |
51 | 188 |
|
52 | | -void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask) |
| 189 | +ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args, |
| 190 | + ucc_base_team_t *team, |
| 191 | + ucc_coll_task_t **task_h) |
53 | 192 | { |
54 | | - ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); |
55 | | - ucc_tl_ucp_team_t *team = TASK_TEAM(task); |
56 | | - ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); |
57 | | - long * pSync = TASK_ARGS(task).global_work_buffer; |
| 193 | + ucc_schedule_t *schedule = NULL; |
| 194 | + ucc_tl_ucp_team_t *tl_team = |
| 195 | + ucc_derived_of(team, ucc_tl_ucp_team_t); |
| 196 | + ucc_base_coll_args_t barrier_coll_args = { |
| 197 | + .team = team->params.team, |
| 198 | + .args.coll_type = UCC_COLL_TYPE_BARRIER, |
| 199 | + }; |
| 200 | + size_t perc_bw = |
| 201 | + UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_percent_bw; |
| 202 | + ucc_on_off_auto_value_t is_get = |
| 203 | + UCC_TL_UCP_TEAM_LIB(tl_team)->cfg.alltoall_onesided_enable_get; |
| 204 | + ucc_coll_task_t *barrier_task; |
| 205 | + ucc_coll_task_t *a2a_task; |
| 206 | + ucc_tl_ucp_task_t *task; |
| 207 | + ucc_status_t status; |
| 208 | + ucc_tl_ucp_schedule_t *tl_schedule; |
| 209 | + size_t nelems; |
| 210 | + size_t rate; |
| 211 | + size_t ratio; |
| 212 | + ucp_ep_h ep; |
| 213 | + ucp_ep_evaluate_perf_param_t param; |
| 214 | + ucp_ep_evaluate_perf_attr_t attr; |
| 215 | + size_t count; |
| 216 | + int64_t npolls; |
| 217 | + ucc_sbgp_t *sbgp; |
58 | 218 |
|
59 | | - if (ucc_tl_ucp_test_onesided(task, gsize) == UCC_INPROGRESS) { |
60 | | - return; |
| 219 | + ALLTOALL_TASK_CHECK(coll_args->args, tl_team); |
| 220 | + if (coll_args->args.mask & UCC_COLL_ARGS_FIELD_FLAGS) { |
| 221 | + if (!(coll_args->args.flags & UCC_COLL_ARGS_FLAG_MEM_MAPPED_BUFFERS)) { |
| 222 | + tl_error(UCC_TL_TEAM_LIB(tl_team), |
| 223 | + "non memory mapped buffers are not supported"); |
| 224 | + status = UCC_ERR_NOT_SUPPORTED; |
| 225 | + goto out; |
| 226 | + } |
| 227 | + } |
| 228 | + if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) { |
| 229 | + coll_args->args.src_memh.global_memh = NULL; |
61 | 230 | } |
| 231 | + if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) { |
| 232 | + coll_args->args.dst_memh.global_memh = NULL; |
| 233 | + } |
| 234 | + status = ucc_tl_ucp_get_schedule(tl_team, coll_args, (ucc_tl_ucp_schedule_t **)&tl_schedule); |
| 235 | + if (ucc_unlikely(UCC_OK != status)) { |
| 236 | + return status; |
| 237 | + } |
| 238 | + schedule = &tl_schedule->super.super; |
| 239 | + ucc_schedule_init(schedule, coll_args, team); |
| 240 | + schedule->super.post = ucc_tl_ucp_alltoall_onesided_sched_start; |
| 241 | + schedule->super.progress = NULL; |
| 242 | + schedule->super.finalize = ucc_tl_ucp_alltoall_onesided_sched_finalize; |
62 | 243 |
|
63 | | - pSync[0] = 0; |
64 | | - task->super.status = UCC_OK; |
| 244 | + sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_NODE); |
| 245 | + task = ucc_tl_ucp_init_task(coll_args, team); |
| 246 | + task->super.post = ucc_tl_ucp_alltoall_onesided_start; |
| 247 | + task->super.progress = ucc_tl_ucp_alltoall_onesided_progress; |
| 248 | + task->super.finalize = ucc_tl_ucp_alltoall_onesided_finalize; |
| 249 | + a2a_task = &task->super; |
| 250 | + |
| 251 | + status = ucc_tl_ucp_coll_init(&barrier_coll_args, team, &barrier_task); |
| 252 | + if (status != UCC_OK) { |
| 253 | + return status; |
| 254 | + } |
| 255 | + if (perc_bw > 100) { |
| 256 | + perc_bw = 100; |
| 257 | + } else if (perc_bw == 0) { |
| 258 | + perc_bw = 1; |
| 259 | + } |
| 260 | + |
| 261 | + nelems = TASK_ARGS(task).src.info.count; |
| 262 | + nelems = nelems / UCC_TL_TEAM_SIZE(tl_team); |
| 263 | + param.field_mask = UCP_EP_PERF_PARAM_FIELD_MESSAGE_SIZE; |
| 264 | + attr.field_mask = UCP_EP_PERF_ATTR_FIELD_ESTIMATED_TIME; |
| 265 | + param.message_size = (1 << 20); |
| 266 | + ucc_tl_ucp_get_ep(tl_team, |
| 267 | + (UCC_TL_TEAM_RANK(tl_team) + 1) % UCC_TL_TEAM_SIZE(tl_team), &ep); |
| 268 | + ucp_ep_evaluate_perf(ep, ¶m, &attr); |
| 269 | + |
| 270 | + rate = (param.message_size / attr.estimated_time) / (param.message_size); |
| 271 | + rate = rate * (double)(perc_bw / 100.0); |
| 272 | + ratio = nelems * sbgp->group_size; |
| 273 | + task->alltoall_onesided.tokens = rate / ratio; |
| 274 | + if (task->alltoall_onesided.tokens < 1) { |
| 275 | + task->alltoall_onesided.tokens = 1; |
| 276 | + } |
| 277 | + npolls = task->n_polls; |
| 278 | + if (is_get == UCC_CONFIG_ON || (is_get == UCC_CONFIG_AUTO && sbgp->group_size >= CONGESTION_THRESHOLD)) { |
| 279 | + int64_t polls; |
| 280 | + count = nelems * ucc_dt_size(TASK_ARGS(task).src.info.datatype); |
| 281 | + polls = count - task->alltoall_onesided.tokens; |
| 282 | + if (polls > task->n_polls) { |
| 283 | + npolls = polls; |
| 284 | + } |
| 285 | + task->alltoall_onesided.op = GET_OP; |
| 286 | + } else { |
| 287 | + task->alltoall_onesided.op = PUT_OP; |
| 288 | + } |
| 289 | + task->alltoall_onesided.npolls = npolls; |
| 290 | + task->alltoall_onesided.iteration = 0; |
| 291 | + |
| 292 | + ucc_schedule_add_task(schedule, a2a_task); |
| 293 | + ucc_task_subscribe_dep(&schedule->super, a2a_task, UCC_EVENT_SCHEDULE_STARTED); |
| 294 | + |
| 295 | + ucc_schedule_add_task(schedule, barrier_task); |
| 296 | + ucc_task_subscribe_dep(a2a_task, barrier_task, |
| 297 | + UCC_EVENT_COMPLETED); |
| 298 | + *task_h = &schedule->super; |
| 299 | +out: |
| 300 | + return status; |
65 | 301 | } |
0 commit comments