From 3f09de6041d0ea82125cca7d85949c9ead98b8ab Mon Sep 17 00:00:00 2001
From: Yemin Shi
Date: Tue, 24 Jun 2025 04:13:51 +0000
Subject: [PATCH] support TORCH_CUDA_ARCH_LIST

---
 setup.py | 39 +++++++++++++++++++++++++++++++--------
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/setup.py b/setup.py
index dd1015d6..ea2bffa2 100644
--- a/setup.py
+++ b/setup.py
@@ -66,15 +66,38 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
-# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+# Determine compute capabilities with priority order:
+# 1. TORCH_CUDA_ARCH_LIST environment variable
+# 2. Auto-detect from current machine GPUs
+# 3. Use all SUPPORTED_ARCHS as fallback
 compute_capabilities = set()
-device_count = torch.cuda.device_count()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 8:
-        warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
-        continue
-    compute_capabilities.add(f"{major}.{minor}")
+
+# First, try to read from TORCH_CUDA_ARCH_LIST environment variable
+torch_cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
+if torch_cuda_arch_list:
+    print(f"Using TORCH_CUDA_ARCH_LIST: {torch_cuda_arch_list}")
+    # Parse TORCH_CUDA_ARCH_LIST (format like "8.0;8.6;8.9;9.0")
+    for arch in torch_cuda_arch_list.replace(";", " ").split():
+        arch = arch.strip()
+        if arch in SUPPORTED_ARCHS:
+            compute_capabilities.add(arch)
+        else:
+            warnings.warn(f"Unsupported architecture {arch} in TORCH_CUDA_ARCH_LIST, skipping")
+else:
+    # Second, try to auto-detect from current machine GPUs
+    device_count = torch.cuda.device_count()
+    if device_count > 0:
+        print(f"Auto-detecting from {device_count} GPU(s) on current machine")
+        for i in range(device_count):
+            major, minor = torch.cuda.get_device_capability(i)
+            if major < 8:
+                warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
+                continue
+            compute_capabilities.add(f"{major}.{minor}")
+    else:
+        # Third, use all SUPPORTED_ARCHS as fallback
+        print("No GPUs detected, using all supported architectures as fallback")
+        compute_capabilities = SUPPORTED_ARCHS.copy()
 
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
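
Notes (illustration only, not part of the commit): a minimal, standalone sketch of
the parsing path added above, showing how an explicit TORCH_CUDA_ARCH_LIST is
filtered against the supported set. SUPPORTED_ARCHS here is a hypothetical
stand-in for the set defined elsewhere in setup.py; the real values may differ.

    import os
    import warnings

    SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0"}  # assumed for illustration

    # Highest-priority path: an explicit TORCH_CUDA_ARCH_LIST ("7.5" should be skipped).
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;7.5"

    compute_capabilities = set()
    for arch in os.environ["TORCH_CUDA_ARCH_LIST"].replace(";", " ").split():
        if arch in SUPPORTED_ARCHS:
            compute_capabilities.add(arch)
        else:
            # "7.5" lands here, mirroring the warning emitted by the patch.
            warnings.warn(f"Unsupported architecture {arch}, skipping")

    print(sorted(compute_capabilities))  # ['8.0', '8.6']

With the variable unset, the build instead falls back to auto-detection via
torch.cuda.get_device_capability, and on a GPU-less builder (e.g. CI) it
targets every entry in SUPPORTED_ARCHS.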