Add SM 8.7, 10.0, 10.1 and 11.0 to the supported NVIDIA GPU architectures
handled by setup.py, and terminate the file with a newline.

NOTE(review): this patch was reconstructed from a copy whose line breaks had
been lost. Line boundaries follow the hunk header counts; the leading
whitespace of the indented lines in the second and third hunks is assumed
(4 / 8 spaces) and must be confirmed against setup.py before applying.

diff --git a/setup.py b/setup.py
index dd1015d6..5243babc 100644
--- a/setup.py
+++ b/setup.py
@@ -26,12 +26,16 @@
 
 HAS_SM80 = False
 HAS_SM86 = False
+HAS_SM87 = False
 HAS_SM89 = False
 HAS_SM90 = False
+HAS_SM100 = False
+HAS_SM101 = False
+HAS_SM110 = False
 HAS_SM120 = False
 
 # Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}
+SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "11.0", "12.0"}
 
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
@@ -103,12 +107,24 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     elif capability.startswith("8.6"):
         HAS_SM86 = True
         num = "86"
+    elif capability.startswith("8.7"):
+        HAS_SM87 = True
+        num = "87"
     elif capability.startswith("8.9"):
         HAS_SM89 = True
         num = "89"
     elif capability.startswith("9.0"):
         HAS_SM90 = True
         num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction.
+    elif capability.startswith("10.0"):
+        HAS_SM100 = True
+        num = "100"
+    elif capability.startswith("10.1"):
+        HAS_SM101 = True
+        num = "101"
+    elif capability.startswith("11.0"):
+        HAS_SM110 = True
+        num = "110"
     elif capability.startswith("12.0"):
         HAS_SM120 = True
         num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.
@@ -185,4 +201,4 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     python_requires='>=3.9',
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
-)
\ No newline at end of file
+)