@@ -38,13 +38,17 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
38
38
if (cu_ctx == NULL || cu_st != CUDA_SUCCESS ) {
39
39
tl_debug (tl_lib ,
40
40
"cannot create CUDA TL team without active CUDA context" );
41
- return UCC_ERR_NO_RESOURCE ;
41
+ team -> device_id = TL_CUDA_DEVICE_INVALID ;
42
+ team -> state = TL_CUDA_STATE_ERROR ;
43
+ goto exchnage_rank_ids ;
42
44
}
43
45
44
46
status = CUDA_FUNC (cudaGetDevice (& team -> device ));
45
47
if (status != UCC_OK ) {
46
48
tl_debug (tl_lib , "failed to get current device id" );
47
- return status ;
49
+ team -> device_id = TL_CUDA_DEVICE_INVALID ;
50
+ team -> state = TL_CUDA_STATE_ERROR ;
51
+ goto exchnage_rank_ids ;
48
52
}
49
53
50
54
status = ucc_tl_cuda_topo_get_pci_id (team -> device , & team -> device_id );
@@ -88,6 +92,7 @@ ucc_status_t ucc_tl_cuda_comm_init_post(ucc_tl_cuda_team_t *team)
88
92
goto free_scratch ;
89
93
}
90
94
95
+ exchnage_rank_ids :
91
96
rank_id -> pci_id = team -> device_id ;
92
97
status = team -> oob .allgather (rank_id , team -> ids , rank_id_size ,
93
98
team -> oob .coll_info , & team -> oob_req );
@@ -127,6 +132,17 @@ ucc_status_t ucc_tl_cuda_comm_init_test(ucc_tl_cuda_team_t *team)
127
132
return status ;
128
133
}
129
134
team -> oob .req_free (team -> oob_req );
135
+ /* check that all ranks have a valid CUDA device set */
136
+ for (r = 0 ; r < tsize ; r ++ ) {
137
+ rank_id = GET_RANK_ID (team -> ids , r , max_concurrent );
138
+ if (ucc_tl_cuda_topo_device_id_equal (& rank_id -> pci_id ,
139
+ & TL_CUDA_DEVICE_INVALID )) {
140
+ tl_debug (tl_lib , "rank %d device is invalid, team can't be created" ,
141
+ r );
142
+ team -> state = TL_CUDA_STATE_ERROR ;
143
+ return UCC_ERR_NO_RESOURCE ;
144
+ }
145
+ }
130
146
131
147
status = ucc_tl_cuda_team_topo_create (& team -> super , & team -> topo );
132
148
if (status != UCC_OK ) {
@@ -234,6 +250,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_team_t, ucc_base_context_t *tl_context,
234
250
self -> stream = NULL ;
235
251
self -> topo = NULL ;
236
252
self -> scratch .loc = NULL ;
253
+ self -> device = -1 ;
237
254
238
255
if (!ucc_team_map_is_single_node (params -> team , params -> map )) {
239
256
tl_debug (tl_context -> lib , "multinode team is not supported" );
0 commit comments