Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,38 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version

# Determine which compute capabilities to build for, in priority order:
#   1. The TORCH_CUDA_ARCH_LIST environment variable, when set and non-empty.
#   2. Auto-detection from the GPUs visible on the current machine.
#   3. Every entry of SUPPORTED_ARCHS, when no GPU is visible.
# To build for specific GPU architectures, set TORCH_CUDA_ARCH_LIST
# (format like "8.0;8.6;8.9;9.0").
compute_capabilities = set()

torch_cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
if torch_cuda_arch_list:
    # The environment variable takes precedence: local GPUs are not probed
    # at all, so the build targets exactly what was requested.
    print(f"Using TORCH_CUDA_ARCH_LIST: {torch_cuda_arch_list}")
    # Entries may be separated by ";" or whitespace. str.split() with no
    # argument already discards empty tokens and surrounding whitespace.
    for arch in torch_cuda_arch_list.replace(";", " ").split():
        if arch in SUPPORTED_ARCHS:
            compute_capabilities.add(arch)
        else:
            warnings.warn(f"Unsupported architecture {arch} in TORCH_CUDA_ARCH_LIST, skipping")
else:
    device_count = torch.cuda.device_count()
    if device_count > 0:
        print(f"Auto-detecting from {device_count} GPU(s) on current machine")
        for i in range(device_count):
            major, minor = torch.cuda.get_device_capability(i)
            # GPUs with a major compute capability below 8 are skipped.
            if major < 8:
                warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
                continue
            compute_capabilities.add(f"{major}.{minor}")
    else:
        print("No GPUs detected, using all supported architectures as fallback")
        # update() keeps compute_capabilities a set regardless of the
        # container type used for SUPPORTED_ARCHS.
        compute_capabilities.update(SUPPORTED_ARCHS)

nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
Expand Down