Skip to content

Commit fa2c434

Browse files
CL/HIER: check number of TLs per SBGP
1 parent e66574d commit fa2c434

File tree

4 files changed

+54
-10
lines changed

4 files changed

+54
-10
lines changed

src/components/cl/hier/cl_hier_team.c

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
*
44
* See file LICENSE for terms.
55
*/
@@ -43,6 +43,11 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
4343
ucc_config_names_array_t *tls;
4444
ucc_subset_t subset;
4545
struct ucc_team_team_desc *d;
46+
ucc_tl_context_t *tl_ctx;
47+
ucc_tl_lib_t *tl_lib;
48+
ucc_base_lib_attr_t attr;
49+
50+
4651
if (!params->team->topo) {
4752
cl_debug(cl_context->lib,
4853
"can't create hier team without topology data");
@@ -74,18 +79,51 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
7479
hs->n_tls = 0;
7580
tls = &lib->cfg.sbgp_tls[i].array;
7681
for (j = 0; j < tls->count; j++) {
82+
if (hs->n_tls == CL_HIER_MAX_SBGP_TLS) {
83+
cl_debug(cl_context->lib,
84+
"skipping tl context %s for %s sbgp: "
85+
"max number of TLs per SBGP is reached",
86+
tls->names[j], ucc_sbgp_str(hs->sbgp_type));
87+
continue;
88+
}
7789
status = ucc_tl_context_get(ctx->super.super.ucc_context,
7890
tls->names[j],
7991
&hs->tl_ctxs[hs->n_tls]);
8092
if (UCC_OK != status) {
8193
cl_debug(cl_context->lib,
8294
"tl context %s is not available for sbgp %s",
8395
tls->names[j], ucc_sbgp_str(hs->sbgp_type));
84-
} else {
85-
hs->n_tls++;
86-
n_sbgp_teams++;
87-
ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS);
96+
continue;
8897
}
98+
attr.mask = UCC_BASE_LIB_ATTR_FIELD_MIN_TEAM_SIZE |
99+
UCC_BASE_LIB_ATTR_FIELD_MAX_TEAM_SIZE;
100+
tl_ctx = hs->tl_ctxs[hs->n_tls];
101+
tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t);
102+
status = tl_lib->iface->lib.get_attr(tl_ctx->super.lib,
103+
&attr);
104+
if (status != UCC_OK) {
105+
cl_debug(cl_context->lib,
106+
"failed to get attributes for tl context %s",
107+
tls->names[j]);
108+
ucc_tl_context_put(tl_ctx);
109+
continue;
110+
}
111+
112+
if (hs->sbgp->group_size < attr.min_team_size ||
113+
hs->sbgp->group_size > attr.max_team_size) {
114+
cl_debug(cl_context->lib,
115+
"tl context %s is not suitable for sbgp %s"
116+
"sbgp: sbgp size %d is not in range [%d; %d]",
117+
tls->names[j], ucc_sbgp_str(hs->sbgp_type),
118+
hs->sbgp->group_size,
119+
attr.min_team_size, attr.max_team_size);
120+
ucc_tl_context_put(tl_ctx);
121+
continue;
122+
}
123+
124+
hs->n_tls++;
125+
n_sbgp_teams++;
126+
ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS);
89127
}
90128
}
91129
}

src/components/ec/cuda/ec_cuda.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,9 @@ ucc_status_t ucc_ec_cuda_get_resources(ucc_ec_cuda_resources_t **resources)
282282
#else
283283
status = CUDADRV_FUNC(cuCtxGetId(cu_ctx, &cu_ctx_id));
284284
if (ucc_unlikely(status != UCC_OK)) {
285-
ec_error(&ucc_ec_cuda.super, "failed to get currect CUDA context ID");
285+
/* workaround for pytorch: progress thread doesn't have a CUDA context for GPU 0 */
286+
cu_ctx_id = 0x12345;
287+
ec_debug(&ucc_ec_cuda.super, "failed to get currect CUDA context ID");
286288
}
287289
#endif
288290

src/components/mc/cuda/mc_cuda.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,9 @@ ucc_status_t ucc_mc_cuda_get_resources(ucc_mc_cuda_resources_t **resources)
368368
#else
369369
status = CUDADRV_FUNC(cuCtxGetId(cu_ctx, &cu_ctx_id));
370370
if (ucc_unlikely(status != UCC_OK)) {
371-
mc_error(&ucc_mc_cuda.super, "failed to get currect CUDA context ID");
371+
/* workaround for pytorch: progress thread doesn't have a CUDA context for GPU 0 */
372+
cu_ctx_id = 0x12345;
373+
mc_debug(&ucc_mc_cuda.super, "failed to get currect CUDA context ID");
372374
}
373375
#endif
374376

test/mpi/main.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* Copyright (c) Advanced Micro Devices, Inc. 2023. ALL RIGHTS RESERVED.
44
*
55
* See file LICENSE for terms.
@@ -574,15 +574,17 @@ void ProcessArgs(int argc, char** argv)
574574

575575
int main(int argc, char *argv[])
576576
{
577-
int failed = 0;
578-
int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4] = {0};
577+
int failed = 0;
578+
int total_done_skipped_failed[ucc_ilog2(UCC_COLL_TYPE_LAST) + 1][4];
579579
std::chrono::steady_clock::time_point begin;
580580
int size, required, provided, completed, rank;
581581
UccTestMpi *test;
582582
MPI_Request req;
583583
std::string err;
584584

585585
begin = std::chrono::steady_clock::now();
586+
memset(total_done_skipped_failed, 0,
587+
sizeof(total_done_skipped_failed));
586588
try {
587589
ProcessArgs(argc, argv);
588590
} catch (const std::string &s) {

0 commit comments

Comments
 (0)