diff --git a/infinity/utils/dist.py b/infinity/utils/dist.py index 5c51f4f..d9fe594 100644 --- a/infinity/utils/dist.py +++ b/infinity/utils/dist.py @@ -34,6 +34,7 @@ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() local_rank = global_rank % num_gpus torch.cuda.set_device(local_rank) + print(f"global_rank:{global_rank} local_rank:{local_rank} num_gpus:{num_gpus}") # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 """ @@ -42,7 +43,9 @@ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout print(f'[dist initialize] mp method={method}') mp.set_start_method(method) """ - tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60)) + tdist.init_process_group(backend=backend, + device_id=torch.device(f'cuda:{local_rank}'), + timeout=datetime.timedelta(seconds=timeout_minutes * 60)) global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill __local_rank = local_rank