3 changes: 2 additions & 1 deletion csrc/qattn/attn_cuda_sm90.h
@@ -16,6 +16,7 @@
 
 #include <torch/extension.h>
 
+bool is_available();
 torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
     torch::Tensor query,
     torch::Tensor key,
@@ -41,4 +42,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
     int is_causal,
     int qk_quant_gran,
     float sm_scale,
-    int return_lse);
\ No newline at end of file
+    int return_lse);
3 changes: 2 additions & 1 deletion csrc/qattn/pybind_sm90.cpp
@@ -22,4 +22,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("qk_int8_sv_f8_accum_f32_attn_inst_buf", &qk_int8_sv_f8_accum_f32_attn_inst_buf);
     m.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf);
-}
\ No newline at end of file
+    m.def("is_available", &is_available);
+}
52 changes: 50 additions & 2 deletions csrc/qattn/qk_int_sv_f8_cuda_sm90.cu
@@ -26,6 +26,54 @@
 
 #include "attn_utils.cuh"
 
+struct _CudaDriverApiLoader {
+    static bool initialized;
+    static bool init() {
+#define LOAD(name) name = load_api<name##_t>((const char *)#name)
+        LOAD(cuTensorMapEncodeTiled);
+#undef LOAD
+        return true;
+    }
+
+#define DECL_API(name) \
+    using name##_t = decltype(&::name); \
+    static name##_t name
+    DECL_API(cuTensorMapEncodeTiled);
+#undef DECL_API
+
+    template <typename FuncType>
+    static FuncType load_api(const char *name) {
+        void *func = nullptr;
+        cudaDriverEntryPointQueryResult qres = {};
+        auto ret =
+#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12050)
+            cudaGetDriverEntryPointByVersion(name, &func, CUDA_VERSION, cudaEnableDefault, &qres);
+#else
+            cudaGetDriverEntryPoint(name, &func, cudaEnableDefault, &qres);
+#endif
+        if (ret == cudaSuccess && qres == cudaDriverEntryPointSuccess && func != nullptr) {
+            return reinterpret_cast<FuncType>(func);
+        } else {
+            return nullptr;
+        }
+    }
+
+    static bool is_available() {
+        return cuTensorMapEncodeTiled != nullptr;
+    }
+};
+
+bool _CudaDriverApiLoader::initialized = _CudaDriverApiLoader::init();
+
+bool is_available() {
+    return _CudaDriverApiLoader::is_available();
+}
+
+#define DECL_API(name) _CudaDriverApiLoader::name##_t _CudaDriverApiLoader::name
+DECL_API(cuTensorMapEncodeTiled);
+#undef DECL_API
+
+
 template <int BlockMajorSize, int BlockMinorSize, bool swizzle=true, CUtensorMapL2promotion_enum promotion_mode=CU_TENSOR_MAP_L2_PROMOTION_NONE, typename T>
 CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) {
     constexpr int smem_stride = BlockMinorSize * sizeof(T);
@@ -39,7 +87,7 @@ CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, in
     uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1};
     uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};
 
-    CUresult result = cuTensorMapEncodeTiled(
+    CUresult result = _CudaDriverApiLoader::cuTensorMapEncodeTiled(
         &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape,
         gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
         (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B,
@@ -913,4 +961,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
     });
 
     return lse;
-}
\ No newline at end of file
+}
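The block above swaps the direct call to `cuTensorMapEncodeTiled` for a pointer resolved at load time through `cudaGetDriverEntryPoint` (or `cudaGetDriverEntryPointByVersion` on CUDA 12.5+), so importing the extension no longer hard-fails on drivers that do not export the TMA encode entry point; `is_available()` simply reports whether the lookup succeeded. Below is a rough standalone Python sketch of the same probe, handy for checking a machine by hand; the helper name and the Linux-only `libcuda.so.1` path are assumptions, not part of this PR.

```python
import ctypes

def driver_exports_symbol(symbol: str = "cuTensorMapEncodeTiled") -> bool:
    """Return True if the installed CUDA driver library exports `symbol`."""
    try:
        libcuda = ctypes.CDLL("libcuda.so.1")  # CUDA driver library on Linux
    except OSError:
        return False  # no CUDA driver installed at all
    return hasattr(libcuda, symbol)  # dlsym-style lookup; False if the symbol is missing

if __name__ == "__main__":
    print("TMA encode available:", driver_exports_symbol())
```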
4 changes: 2 additions & 2 deletions sageattention/core.py
@@ -40,7 +40,7 @@
 
 try:
     from . import _qattn_sm90
-    SM90_ENABLED = True
+    SM90_ENABLED = _qattn_sm90.is_available()
 except:
     SM90_ENABLED = False
 
@@ -952,4 +952,4 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
     if return_lse:
         return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
     else:
-        return o
\ No newline at end of file
+        return o
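Because `SM90_ENABLED` now reflects a runtime driver check rather than a build-time constant, downstream code can keep a single call site and fall back cleanly when the Hopper FP8 path is unusable. A hedged sketch of such a guard follows; the kernel's exact signature and keyword names are assumptions, not taken from this diff.

```python
# Illustrative caller-side fallback; sageattention's own dispatch may differ.
import torch
import torch.nn.functional as F
from sageattention import core

def attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
    if core.SM90_ENABLED:
        # Extension built with sm_90a support *and* the driver exposes the TMA encode API.
        return core.sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, is_causal=is_causal)
    # Otherwise fall back to PyTorch's reference attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```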
13 changes: 9 additions & 4 deletions setup.py
@@ -67,7 +67,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     return nvcc_cuda_version
 
 # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
-compute_capabilities = set()
+compute_capabilities = {arch for arch in os.getenv("TORCH_CUDA_ARCH_LIST", "").split(";")
+                        if arch in SUPPORTED_ARCHS or (arch.endswith("+PTX") and arch[:-len("+PTX")] in SUPPORTED_ARCHS)}
 device_count = torch.cuda.device_count()
 for i in range(device_count):
     major, minor = torch.cuda.get_device_capability(i)
@@ -95,6 +96,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError(
         "CUDA 12.8 or higher is required for compute capability 12.0.")
 
+NVCC_FLAGS_NO_CODE = NVCC_FLAGS.copy()
+NVCC_FLAGS_CODE = dict()
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
     if capability.startswith("8.0"):
@@ -113,8 +116,10 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
         HAS_SM120 = True
         num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.
         NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        NVCC_FLAGS_CODE[num] = NVCC_FLAGS[-2:]
         if capability.endswith("+PTX"):
             NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+            NVCC_FLAGS_CODE[num] += NVCC_FLAGS[-2:]
 
 ext_modules = []
 
@@ -148,7 +153,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     ],
     extra_compile_args={
         "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
+        "nvcc": NVCC_FLAGS_NO_CODE + NVCC_FLAGS_CODE.get("89", []) + NVCC_FLAGS_CODE.get("120", []),
     },
 )
 ext_modules.append(qattn_extension)
@@ -162,9 +167,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     ],
     extra_compile_args={
         "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
+        "nvcc": NVCC_FLAGS_NO_CODE + NVCC_FLAGS_CODE.get("90a", []),
     },
-    extra_link_args=['-lcuda'],
+    # extra_link_args=['-lcuda'],  # no need to link explicitly against libcuda anymore
 )
 ext_modules.append(qattn_extension)
 
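Seeding `compute_capabilities` from `TORCH_CUDA_ARCH_LIST` lets users cross-compile for architectures that are not present on the build machine, while the existing loop over `torch.cuda.device_count()` still adds whatever GPUs are detected locally. A standalone sketch of just the environment-variable filter is below; the `SUPPORTED_ARCHS` value is an assumption for illustration, the real set is defined elsewhere in setup.py.

```python
import os

# Stand-in for the SUPPORTED_ARCHS set defined elsewhere in setup.py.
SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}

os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9;9.0+PTX;7.5"  # example value

compute_capabilities = {
    arch
    for arch in os.getenv("TORCH_CUDA_ARCH_LIST", "").split(";")
    if arch in SUPPORTED_ARCHS
    or (arch.endswith("+PTX") and arch[: -len("+PTX")] in SUPPORTED_ARCHS)
}
print(compute_capabilities)  # contains '8.9' and '9.0+PTX'; unsupported '7.5' is dropped silently
```

The split into `NVCC_FLAGS_NO_CODE` plus per-architecture `NVCC_FLAGS_CODE` entries then lets each extension receive only the `-gencode` pairs it needs: the SM90 module gets the `90a` codegen, while the other qattn module gets the `89` and `120` codegen.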