diff --git a/csrc/qattn/attn_cuda_sm90.h b/csrc/qattn/attn_cuda_sm90.h
index 822ab7f4..d0c787c9 100644
--- a/csrc/qattn/attn_cuda_sm90.h
+++ b/csrc/qattn/attn_cuda_sm90.h
@@ -16,6 +16,7 @@
 
 #include <torch/extension.h>
 
+bool is_available();
 torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
                 torch::Tensor query,
                 torch::Tensor key,
@@ -41,4 +42,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
                 int is_causal,
                 int qk_quant_gran,
                 float sm_scale,
-                int return_lse);
\ No newline at end of file
+                int return_lse);
diff --git a/csrc/qattn/pybind_sm90.cpp b/csrc/qattn/pybind_sm90.cpp
index 8900b641..c9113fb2 100644
--- a/csrc/qattn/pybind_sm90.cpp
+++ b/csrc/qattn/pybind_sm90.cpp
@@ -22,4 +22,5 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("qk_int8_sv_f8_accum_f32_attn_inst_buf", &qk_int8_sv_f8_accum_f32_attn_inst_buf);
   m.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf);
-}
\ No newline at end of file
+  m.def("is_available", &is_available);
+}
diff --git a/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu b/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu
index e9e5ccf5..2f48dc95 100644
--- a/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu
+++ b/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu
@@ -26,6 +26,54 @@
 
 #include "attn_utils.cuh"
 
+struct _CudaDriverApiLoader {
+  static bool initialized;
+  static bool init() {
+    #define LOAD(name) name = load_api<name##_t>((const char *)#name)
+    LOAD(cuTensorMapEncodeTiled);
+    #undef LOAD
+    return true;
+  }
+
+#define DECL_API(name) \
+  using name##_t = decltype(&::name); \
+  static name##_t name
+  DECL_API(cuTensorMapEncodeTiled);
+#undef DECL_API
+
+  template <typename FuncType>
+  static FuncType load_api(const char *name) {
+    void *func = nullptr;
+    cudaDriverEntryPointQueryResult qres = {};
+    auto ret =
+    #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12050)
+      cudaGetDriverEntryPointByVersion(name, &func, CUDA_VERSION, cudaEnableDefault, &qres);
+    #else
+      cudaGetDriverEntryPoint(name, &func, cudaEnableDefault, &qres);
+    #endif
+    if (ret == cudaSuccess && qres == cudaDriverEntryPointSuccess && func != nullptr) {
+      return reinterpret_cast<FuncType>(func);
+    } else {
+      return nullptr;
+    }
+  }
+
+  static bool is_available() {
+    return cuTensorMapEncodeTiled != nullptr;
+  }
+};
+
+bool _CudaDriverApiLoader::initialized = _CudaDriverApiLoader::init();
+
+bool is_available() {
+  return _CudaDriverApiLoader::is_available();
+}
+
+#define DECL_API(name) _CudaDriverApiLoader::name##_t _CudaDriverApiLoader::name
+DECL_API(cuTensorMapEncodeTiled);
+#undef DECL_API
+
+
 template <int BlockMajorSize, int BlockMinorSize, bool swizzle=true, CUtensorMapL2promotion_enum promotion_mode=CU_TENSOR_MAP_L2_PROMOTION_NONE, typename T>
 CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) {
   constexpr int smem_stride = BlockMinorSize * sizeof(T);
@@ -39,7 +87,7 @@ CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) {
   uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1};
   uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};
 
-  CUresult result = cuTensorMapEncodeTiled(
+  CUresult result = _CudaDriverApiLoader::cuTensorMapEncodeTiled(
     &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape, gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
     (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B,
     promotion_mode, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
@@ -913,4 +961,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
   });
 
   return lse;
-}
\ No newline at end of file
+}
diff --git a/sageattention/core.py b/sageattention/core.py
index c8829e45..9250e9bb 100644
--- a/sageattention/core.py
+++ b/sageattention/core.py
@@ -40,7 +40,7 @@
 
 try:
     from . import _qattn_sm90
-    SM90_ENABLED = True
+    SM90_ENABLED = _qattn_sm90.is_available()
 except:
     SM90_ENABLED = False
 
@@ -952,4 +952,4 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
     if return_lse:
         return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
     else:
-        return o
\ No newline at end of file
+        return o
diff --git a/setup.py b/setup.py
index 5e4779dd..da7df465 100644
--- a/setup.py
+++ b/setup.py
@@ -67,7 +67,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     return nvcc_cuda_version
 
 # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
-compute_capabilities = set()
+compute_capabilities = {arch for arch in os.getenv("TORCH_CUDA_ARCH_LIST", "").split(";")
+                        if arch in SUPPORTED_ARCHS or (arch.endswith("+PTX") and arch[:-len("+PTX")] in SUPPORTED_ARCHS)}
 device_count = torch.cuda.device_count()
 for i in range(device_count):
     major, minor = torch.cuda.get_device_capability(i)
@@ -95,6 +96,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError(
         "CUDA 12.8 or higher is required for compute capability 12.0.")
 
+NVCC_FLAGS_NO_CODE = NVCC_FLAGS.copy()
+NVCC_FLAGS_CODE = dict()
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
     if capability.startswith("8.0"):
@@ -113,8 +116,10 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
         HAS_SM120 = True
         num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.
     NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+    NVCC_FLAGS_CODE[num] = NVCC_FLAGS[-2:]
     if capability.endswith("+PTX"):
         NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+        NVCC_FLAGS_CODE[num] += NVCC_FLAGS[-2:]
 
 
 ext_modules = []
@@ -148,7 +153,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
         ],
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": NVCC_FLAGS_NO_CODE + NVCC_FLAGS_CODE.get("89", []) + NVCC_FLAGS_CODE.get("120", []),
         },
     )
     ext_modules.append(qattn_extension)
@@ -162,9 +167,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
         ],
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": NVCC_FLAGS_NO_CODE + NVCC_FLAGS_CODE.get("90a", []),
         },
-        extra_link_args=['-lcuda'],
+        # extra_link_args=['-lcuda'],  # no need to link against the CUDA driver explicitly
     )
     ext_modules.append(qattn_extension)
 