|  | 
| 1 | 1 | /** | 
| 2 |  | - * Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
|  | 2 | + * Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
| 3 | 3 |  * | 
| 4 | 4 |  * See file LICENSE for terms. | 
| 5 | 5 |  */ | 
| @@ -43,6 +43,11 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, | 
| 43 | 43 |     ucc_config_names_array_t  *tls; | 
| 44 | 44 |     ucc_subset_t               subset; | 
| 45 | 45 |     struct ucc_team_team_desc *d; | 
|  | 46 | +    ucc_tl_context_t          *tl_ctx; | 
|  | 47 | +    ucc_tl_lib_t              *tl_lib; | 
|  | 48 | +    ucc_base_lib_attr_t        attr; | 
|  | 49 | + | 
|  | 50 | + | 
| 46 | 51 |     if (!params->team->topo) { | 
| 47 | 52 |         cl_debug(cl_context->lib, | 
| 48 | 53 |                 "can't create hier team without topology data"); | 
| @@ -74,18 +79,51 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context, | 
| 74 | 79 |             hs->n_tls = 0; | 
| 75 | 80 |             tls       = &lib->cfg.sbgp_tls[i].array; | 
| 76 | 81 |             for (j = 0; j < tls->count; j++) { | 
|  | 82 | +                if (hs->n_tls == CL_HIER_MAX_SBGP_TLS) { | 
|  | 83 | +                    cl_debug(cl_context->lib, | 
|  | 84 | +                             "skipping tl context %s for %s sbgp: " | 
|  | 85 | +                             "max number of TLs per SBGP is reached", | 
|  | 86 | +                             tls->names[j], ucc_sbgp_str(hs->sbgp_type)); | 
|  | 87 | +                    continue; | 
|  | 88 | +                } | 
| 77 | 89 |                 status = ucc_tl_context_get(ctx->super.super.ucc_context, | 
| 78 | 90 |                                             tls->names[j], | 
| 79 | 91 |                                             &hs->tl_ctxs[hs->n_tls]); | 
| 80 | 92 |                 if (UCC_OK != status) { | 
| 81 | 93 |                     cl_debug(cl_context->lib, | 
| 82 | 94 |                              "tl context %s is not available for sbgp %s", | 
| 83 | 95 |                              tls->names[j], ucc_sbgp_str(hs->sbgp_type)); | 
| 84 |  | -                } else { | 
| 85 |  | -                    hs->n_tls++; | 
| 86 |  | -                    n_sbgp_teams++; | 
| 87 |  | -                    ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS); | 
|  | 96 | +                    continue; | 
| 88 | 97 |                 } | 
|  | 98 | +                attr.mask = UCC_BASE_LIB_ATTR_FIELD_MIN_TEAM_SIZE | | 
|  | 99 | +                            UCC_BASE_LIB_ATTR_FIELD_MAX_TEAM_SIZE; | 
|  | 100 | +                tl_ctx = hs->tl_ctxs[hs->n_tls]; | 
|  | 101 | +                tl_lib = ucc_derived_of(tl_ctx->super.lib, ucc_tl_lib_t); | 
|  | 102 | +                status = tl_lib->iface->lib.get_attr(tl_ctx->super.lib, | 
|  | 103 | +                                                     &attr); | 
|  | 104 | +                if (status != UCC_OK) { | 
|  | 105 | +                    cl_debug(cl_context->lib, | 
|  | 106 | +                             "failed to get attributes for tl context %s", | 
|  | 107 | +                             tls->names[j]); | 
|  | 108 | +                    ucc_tl_context_put(tl_ctx); | 
|  | 109 | +                    continue; | 
|  | 110 | +                } | 
|  | 111 | + | 
|  | 112 | +                if (hs->sbgp->group_size < attr.min_team_size || | 
|  | 113 | +                    hs->sbgp->group_size > attr.max_team_size) { | 
|  | 114 | +                    cl_debug(cl_context->lib, | 
|  | 115 | +                            "tl context %s is not suitable for sbgp %s" | 
|  | 116 | +                            "sbgp: sbgp size %d is not in range [%d; %d]", | 
|  | 117 | +                            tls->names[j], ucc_sbgp_str(hs->sbgp_type), | 
|  | 118 | +                            hs->sbgp->group_size, | 
|  | 119 | +                            attr.min_team_size, attr.max_team_size); | 
|  | 120 | +                    ucc_tl_context_put(tl_ctx); | 
|  | 121 | +                    continue; | 
|  | 122 | +                } | 
|  | 123 | + | 
|  | 124 | +                hs->n_tls++; | 
|  | 125 | +                n_sbgp_teams++; | 
|  | 126 | +                ucc_assert(hs->n_tls <= CL_HIER_MAX_SBGP_TLS); | 
| 89 | 127 |             } | 
| 90 | 128 |         } | 
| 91 | 129 |     } | 
|  | 
0 commit comments