diff --git a/src/components/cl/hier/cl_hier.h b/src/components/cl/hier/cl_hier.h index cae41cea322..fb0a9d2efbe 100644 --- a/src/components/cl/hier/cl_hier.h +++ b/src/components/cl/hier/cl_hier.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Meta Platforms, Inc. and affiliates. 2022. * * See file LICENSE for terms. @@ -13,6 +13,7 @@ #include "coll_score/ucc_coll_score.h" #include "utils/ucc_mpool.h" #include "schedule/ucc_schedule_pipelined.h" +#include "core/ucc_service_coll.h" #ifdef HAVE_PROFILING_CL_HIER #include "utils/profile/ucc_profile_on.h" @@ -98,13 +99,20 @@ typedef struct ucc_hier_sbgp { int n_tls; } ucc_hier_sbgp_t; +typedef struct ucc_cl_hier_team_create_req { + ucc_team_multiple_req_t *create_req; + ucc_service_coll_req_t *global_status_req; + ucc_status_t local_status; + ucc_status_t global_status; +} ucc_cl_hier_team_create_req_t; + typedef struct ucc_cl_hier_team { - ucc_cl_team_t super; - ucc_team_multiple_req_t *team_create_req; - unsigned n_tl_teams; - ucc_coll_score_t *score; - ucc_hier_sbgp_t sbgps[UCC_HIER_SBGP_LAST]; - ucc_hier_sbgp_type_t top_sbgp; + ucc_cl_team_t super; + ucc_cl_hier_team_create_req_t *team_req; + unsigned n_tl_teams; + ucc_coll_score_t *score; + ucc_hier_sbgp_t sbgps[UCC_HIER_SBGP_LAST]; + ucc_hier_sbgp_type_t top_sbgp; } ucc_cl_hier_team_t; UCC_CLASS_DECLARE(ucc_cl_hier_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); diff --git a/src/components/cl/hier/cl_hier_team.c b/src/components/cl/hier/cl_hier_team.c index fad67073099..02d62467c91 100644 --- a/src/components/cl/hier/cl_hier_team.c +++ b/src/components/cl/hier/cl_hier_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -7,7 +7,6 @@ #include "cl_hier.h" #include "utils/ucc_malloc.h" #include "core/ucc_team.h" -#include "core/ucc_service_coll.h" #include "cl_hier_coll.h" #define SBGP_SET(_team, _sbgp, _enable) \ @@ -41,6 +40,8 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, ucc_config_names_array_t *tls; ucc_subset_t subset; struct ucc_team_team_desc *d; + ucc_team_multiple_req_t *team_create_req; + if (!params->team->topo) { cl_info(cl_context->lib, "can't create hier team without topology data"); @@ -53,6 +54,12 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, } UCC_CLASS_CALL_SUPER_INIT(ucc_cl_team_t, &ctx->super, params); + self->team_req = (ucc_cl_hier_team_create_req_t*) + ucc_malloc(sizeof(ucc_cl_hier_team_create_req_t)); + if (!self->team_req) { + return UCC_ERR_NO_MEMORY; + } + memset(self->sbgps, 0, sizeof(self->sbgps)); ucc_cl_hier_enable_sbgps(self); n_sbgp_teams = 0; @@ -88,7 +95,7 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, } } - status = ucc_team_multiple_req_alloc(&self->team_create_req, n_sbgp_teams); + status = ucc_team_multiple_req_alloc(&team_create_req, n_sbgp_teams); if (UCC_OK != status) { cl_error(cl_context->lib, "failed to allocate team req multiple"); goto err; @@ -102,7 +109,7 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, hs = &self->sbgps[i]; if (hs->state == UCC_HIER_SBGP_ENABLED) { for (t = 0; t < hs->n_tls; t++) { - d = &self->team_create_req->descs[j]; + d = &team_create_req->descs[j]; d->param.params.mask = UCC_TEAM_PARAM_FIELD_EP_RANGE | UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_TEAM_SIZE | @@ -134,15 +141,18 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, } } - status = ucc_tl_team_create_multiple(self->team_create_req); + + status = ucc_tl_team_create_multiple(team_create_req); if (status < 0) { cl_error(cl_context->lib, "failed to post tl team create (%d)", status); goto err; } + self->team_req->create_req = team_create_req; + self->team_req->global_status_req = NULL; cl_info(cl_context->lib, "posted cl team: %p", self); return UCC_OK; err: - ucc_team_multiple_req_free(self->team_create_req); + ucc_team_multiple_req_free(team_create_req); return status; } @@ -162,15 +172,15 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team) int i, j; ucc_hier_sbgp_t *hs; - if (NULL == team->team_create_req) { - status = ucc_team_multiple_req_alloc(&team->team_create_req, + if (NULL == team->team_req->create_req) { + status = ucc_team_multiple_req_alloc(&team->team_req->create_req, team->n_tl_teams); if (UCC_OK != status) { cl_error(ctx->super.super.lib, "failed to allocate team req multiple"); return status; } - team->team_create_req->n_teams = 0; + team->team_req->create_req->n_teams = 0; for (i = 0; i < UCC_HIER_SBGP_LAST; i++) { hs = &team->sbgps[i]; if (hs->state == UCC_HIER_SBGP_ENABLED) { @@ -180,26 +190,27 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team) for (j = 0; j < hs->n_tls; j++) { if (hs->tl_teams[j]) { ucc_tl_context_put(hs->tl_ctxs[j]); - team->team_create_req - ->descs[team->team_create_req->n_teams++] + team->team_req->create_req + ->descs[team->team_req->create_req->n_teams++] .team = hs->tl_teams[j]; } } } } } - status = ucc_tl_team_destroy_multiple(team->team_create_req); + status = ucc_tl_team_destroy_multiple(team->team_req->create_req); if (UCC_INPROGRESS == status) { return status; } - for (i = 0; i < team->team_create_req->n_teams; i++) { - if (team->team_create_req->descs[i].status != UCC_OK) { + for (i = 0; i < team->team_req->create_req->n_teams; i++) { + if (team->team_req->create_req->descs[i].status != UCC_OK) { cl_error(ctx->super.super.lib, "tl team destroy failed (%d)", status); - status = team->team_create_req->descs[i].status; + status = team->team_req->create_req->descs[i].status; } } - ucc_team_multiple_req_free(team->team_create_req); + ucc_team_multiple_req_free(team->team_req->create_req); + ucc_free(team->team_req); UCC_CLASS_DELETE_FUNC_NAME(ucc_cl_hier_team_t)(cl_team); return status; } @@ -208,15 +219,23 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team) { ucc_cl_hier_team_t *team = ucc_derived_of(cl_team, ucc_cl_hier_team_t); ucc_cl_hier_context_t *ctx = UCC_CL_HIER_TEAM_CTX(team); - ucc_status_t status; - int i; - ucc_coll_score_t *score, *score_merge; + ucc_status_t status; + ucc_coll_score_t *score, *score_merge; struct ucc_team_team_desc *d; ucc_hier_sbgp_t *hs; + ucc_subset_t subset; + int i; - status = ucc_tl_team_create_multiple(team->team_create_req); - if (status != UCC_OK) { + if (team->team_req->global_status_req) { + /* all team create stages are done, checking global status */ + goto check_global_status; + } + + status = ucc_tl_team_create_multiple(team->team_req->create_req); + if (status == UCC_INPROGRESS) { return status; + } else if (status != UCC_OK) { + goto check_global_status; } team->n_tl_teams = 0; @@ -224,8 +243,8 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team) /* TL teams are created: get scores and merge them to produce * score map for each sbgp */ - for (i = 0; i < team->team_create_req->n_teams; i++) { - d = &team->team_create_req->descs[i]; + for (i = 0; i < team->team_req->create_req->n_teams; i++) { + d = &team->team_req->create_req->descs[i]; ucc_hier_sbgp_type_t st = (ucc_hier_sbgp_type_t)d->args[0]; int tl = (int)d->args[1]; @@ -288,8 +307,8 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team) } } } - ucc_team_multiple_req_free(team->team_create_req); - team->team_create_req = NULL; + ucc_team_multiple_req_free(team->team_req->create_req); + team->team_req->create_req = NULL; if (SBGP_EXISTS(team, NODE_LEADERS)) { team->top_sbgp = UCC_HIER_SBGP_NODE_LEADERS; @@ -298,7 +317,32 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team) team->top_sbgp = UCC_HIER_SBGP_NODE; } - return status; +check_global_status: + if (!team->team_req->global_status_req) { + subset.map.type = UCC_EP_MAP_FULL; + subset.map.ep_num = team->super.super.params.size; + subset.myrank = team->super.super.params.rank; + team->team_req->local_status = status; + status = ucc_service_allreduce(team->super.super.params.team, + &team->team_req->local_status, + &team->team_req->global_status, + UCC_DT_INT32, 1, UCC_OP_MIN, subset, + &team->team_req->global_status_req); + if (status != UCC_OK) { + cl_error(ctx->super.super.lib, "failed to start service allreduce"); + return status; + } + } + + status = ucc_service_coll_test(team->team_req->global_status_req); + if (status == UCC_INPROGRESS) { + return status; + } + ucc_service_coll_finalize(team->team_req->global_status_req); + if (status != UCC_OK) { + return status; + } + return team->team_req->global_status; } ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team, diff --git a/src/core/ucc_team.c b/src/core/ucc_team.c index 354e8777191..1d7919076ba 100644 --- a/src/core/ucc_team.c +++ b/src/core/ucc_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Meta Platforms, Inc. and affiliates. 2022. * * See file LICENSE for terms. @@ -282,7 +282,8 @@ static ucc_status_t ucc_team_create_cls(ucc_context_t *context, status = cl_iface->team.create_test(b_team); if (status < 0) { team->n_cl_teams--; - ucc_info("failed to create CL %s team", cl_iface->super.name); + ucc_info("failed to create CL %s team, team_id %d", + cl_iface->super.name, team->id); cl_iface->team.destroy(b_team); } else if (status == UCC_INPROGRESS) { return status; @@ -294,12 +295,14 @@ static ucc_status_t ucc_team_create_cls(ucc_context_t *context, status = cl_iface->team.create_post(&context->cl_ctx[i]->super, &team->bp, &b_team); if (status != UCC_OK) { - ucc_info("failed to create CL %s team", cl_iface->super.name); + ucc_info("failed to create CL %s team, team_id %d", + cl_iface->super.name, team->id); continue; } status = cl_iface->team.create_test(b_team); if (status < 0) { - ucc_info("failed to create CL %s team", cl_iface->super.name); + ucc_info("failed to create CL %s team, team_id %d", + cl_iface->super.name, team->id); cl_iface->team.destroy(b_team); continue; }